//===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
10#include "mlir/Analysis/SliceAnalysis.h"
11#include "mlir/Dialect/Linalg/IR/Linalg.h"
12#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
13#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
14#include "mlir/Dialect/Linalg/Utils/Utils.h"
15#include "mlir/Dialect/Transform/IR/TransformTypes.h"
16#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
17#include "mlir/IR/BuiltinAttributes.h"
18#include "mlir/Interfaces/FunctionImplementation.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/FormatVariadic.h"
21
22using namespace mlir;
23
24#define DEBUG_TYPE "linalg-transforms"
25#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
26
27//===----------------------------------------------------------------------===//
28// StructuredMatchOp
29//===----------------------------------------------------------------------===//
30
31DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
32 Operation *current, transform::TransformResults &results,
33 transform::TransformState &state) {
34 // First, check if the payload operation is a structured Linalg operation.
35 if (!isa<linalg::LinalgOp>(current)) {
36 if (getFailurePropagationMode().value_or(
37 FailurePropagationMode::Propagate) ==
38 FailurePropagationMode::Propagate) {
39 return emitSilenceableError() << "expected a Linalg op";
40 }
41 // If errors are suppressed, succeed and set all results to empty lists.
42 LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op");
43 results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
44 return DiagnosedSilenceableFailure::success();
45 }
46
47 // Bind `current` to the block argument.
48 auto scope = state.make_region_scope(getBodyRegion());
49 if (failed(state.mapBlockArgument(getBody()->getArgument(0),
50 MappedValue(current)))) {
51 return DiagnosedSilenceableFailure::definiteFailure();
52 }
53
54 for (Operation &nested : getBody()->without_terminator()) {
55 DiagnosedSilenceableFailure diag =
56 state.applyTransform(cast<TransformOpInterface>(nested));
57 if (diag.isDefiniteFailure())
58 return diag;
59 if (diag.succeeded())
60 continue;
61
62 // If propagating errors, do this immediately.
63 assert(diag.isSilenceableFailure());
64 if (getFailurePropagationMode().value_or(
65 FailurePropagationMode::Propagate) ==
66 FailurePropagationMode::Propagate) {
67 return diag;
68 }
69
70 // If suppressing errors, print the message into the debug stream before
71 // silencing it. Then set all results value that are already known.
72 // Results come from the terminator operands, which may be defined in the
73 // (single) block of this operation or above it. When they are defined
74 // above, they are known to be mapped at this point per SSA dominance.
75 // When they are defined in this block, we additionally check if we have
76 // already applied the operation that defines them. If not, the
77 // corresponding results will be set to empty lists.
78 LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
79 << "\n");
80 (void)diag.silence();
81 SmallVector<OpOperand *> undefinedOperands;
82 for (OpOperand &terminatorOperand :
83 getBody()->getTerminator()->getOpOperands()) {
84 Operation *definingOp = terminatorOperand.get().getDefiningOp();
85 if (!definingOp)
86 continue;
87 if (definingOp->getBlock() != getBody())
88 continue;
89 if (definingOp->isBeforeInBlock(&nested))
90 continue;
91
92 undefinedOperands.push_back(&terminatorOperand);
93 }
94
95 SmallVector<SmallVector<transform::MappedValue>> mappings;
96 auto filtered = llvm::make_filter_range(
97 getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) {
98 return !llvm::is_contained(undefinedOperands, &opOperand);
99 });
100 SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range(
101 filtered, [](OpOperand &opOperand) { return opOperand.get(); }));
102 detail::prepareValueMappings(mappings, definedOperands, state);
103 for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
104 results.setMappedValues(getResults()[operand.getOperandNumber()],
105 mapping);
106 }
107 results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
108 return DiagnosedSilenceableFailure::success();
109 }
110
111 // Set the results.
112 detail::forwardTerminatorOperands(getBody(), state, results);
113 return DiagnosedSilenceableFailure::success();
114}
115
116void transform::MatchStructuredOp::getEffects(
117 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
118 onlyReadsHandle(getCurrent(), effects);
119 onlyReadsPayload(effects);
120 producesHandle(getOutputs(), effects);
121}
122
123LogicalResult transform::MatchStructuredOp::verify() {
124 if (getBody()->getNumArguments() != 1)
125 return emitOpError() << "expected one body argument";
126 if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) {
127 return emitOpError() << "expected body argument to implement "
128 "TransformHandleTypeInterface";
129 }
130 for (Operation &nested : getBody()->without_terminator()) {
131 if (isa<MatchOpInterface>(nested))
132 continue;
133 InFlightDiagnostic diag =
134 emitOpError()
135 << "expects nested operations to implement MatchOpInterface";
136 diag.attachNote(nested.getLoc()) << "offending operation";
137 return diag;
138 }
139 return success();
140}
141
142//===----------------------------------------------------------------------===//
143// StructuredOpPredicateOpTrait
144//===----------------------------------------------------------------------===//
145
146LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait(
147 Operation *op, Value structuredOpHandle) {
148 if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) {
149 return op->emitOpError() << "expects parent op to be '"
150 << MatchStructuredOp::getOperationName() << "'";
151 }
152
153 // Bail out here, let the verifier of the parent complain.
154 Operation *parent = op->getParentOp();
155 if (parent->getNumRegions() < 1 || parent->getRegion(index: 0).empty() ||
156 parent->getRegion(index: 0).front().getNumArguments() < 1)
157 return success();
158
159 if (structuredOpHandle != parent->getRegion(index: 0).front().getArgument(i: 0)) {
160 return op->emitOpError()
161 << "expected predicate to apply to the surrounding structured op";
162 }
163 return success();
164}
165
166//===----------------------------------------------------------------------===//
167// MatchStructuredBodyOp
168//===----------------------------------------------------------------------===//
169
170DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
171 Operation *current, transform::TransformResults &results,
172 transform::TransformState &state) {
173 auto linalgOp = cast<linalg::LinalgOp>(current);
174 if (std::optional<uint64_t> position = getReductionPosition()) {
175 SmallVector<Operation *> combinerOps;
176 if (!matchReduction(linalgOp.getRegionOutputArgs(), *position,
177 combinerOps)) {
178 return emitSilenceableError() << "could not match reduction";
179 }
180 if (combinerOps.size() != 1) {
181 return emitSilenceableError() << "reduction combiner is not a single op";
182 }
183 return DiagnosedSilenceableFailure::success();
184 }
185 if (getPassthrough()) {
186 Block &body = linalgOp->getRegion(0).front();
187 if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
188 return emitSilenceableError() << "not a passthrough";
189 }
190 return DiagnosedSilenceableFailure::success();
191 }
192 if (getElementwise()) {
193 if (!isElementwise(linalgOp))
194 return emitSilenceableError() << "not elementwise";
195 return DiagnosedSilenceableFailure::success();
196 }
197 if (std::optional<ArrayAttr> contractionOps = getContraction()) {
198 Block &body = linalgOp->getRegion(0).front();
199 std::string message;
200 llvm::raw_string_ostream os(message);
201 bool result = linalg::detail::isContractionBody(
202 body,
203 [&](Operation *elem, Operation *red) {
204 return elem->getName().getStringRef() ==
205 cast<StringAttr>((*contractionOps)[0]).getValue() &&
206 red->getName().getStringRef() ==
207 cast<StringAttr>((*contractionOps)[1]).getValue();
208 },
209 os);
210 if (result)
211 return DiagnosedSilenceableFailure::success();
212 return emitSilenceableError() << "contraction: " << os.str();
213 }
214 return emitDefiniteFailure() << "unknown body condition";
215}
216
217LogicalResult transform::MatchStructuredBodyOp::verify() {
218 int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
219 getElementwise() + getContraction().has_value();
220
221 if (numOptions > 1) {
222 std::string attributeNames;
223 llvm::raw_string_ostream os(attributeNames);
224 llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(),
225 getPassthroughAttrName(),
226 getElementwiseAttrName(),
227 getContractionAttrName()},
228 os);
229 return emitOpError() << "only one of {" << os.str() << "} is allowed";
230 }
231
232 if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
233 if (contractionAttr->size() != 2) {
234 return emitOpError() << "expects " << getContractionAttrName()
235 << " to contain two elements";
236 }
237 }
238 return success();
239}
240
241//===----------------------------------------------------------------------===//
242// MatchStructuredClassifyContractionDimsOp
243//===----------------------------------------------------------------------===//
244
245DiagnosedSilenceableFailure
246transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
247 Operation *current, transform::TransformResults &results,
248 transform::TransformState &state) {
249 FailureOr<linalg::ContractionDimensions> contractionDims =
250 linalg::inferContractionDims(cast<linalg::LinalgOp>(current));
251 if (failed(contractionDims))
252 return emitSilenceableError() << "could not infer contraction dimensions";
253
254 MLIRContext *context = current->getContext();
255 Builder builder(context);
256 auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
257 return llvm::to_vector(
258 llvm::map_range(values, [&](unsigned value) -> Attribute {
259 return builder.getI64IntegerAttr(value);
260 }));
261 };
262 results.setParams(cast<OpResult>(getBatch()),
263 makeI64Attrs(contractionDims->batch));
264 results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
265 results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
266 results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
267 return DiagnosedSilenceableFailure::success();
268}
269
270//===----------------------------------------------------------------------===//
271// MatchStructuredClassifyConvolutionDimsOp
272//===----------------------------------------------------------------------===//
273
274DiagnosedSilenceableFailure
275transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
276 Operation *current, transform::TransformResults &results,
277 transform::TransformState &state) {
278 FailureOr<linalg::ConvolutionDimensions> convolutionDims =
279 linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
280 if (failed(convolutionDims))
281 return emitSilenceableError() << "could not infer convolution dimensions";
282
283 MLIRContext *context = current->getContext();
284 Builder builder(context);
285 auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
286 return llvm::to_vector(
287 llvm::map_range(values, [&](unsigned value) -> Attribute {
288 return builder.getI64IntegerAttr(value);
289 }));
290 };
291 results.setParams(cast<OpResult>(getBatch()),
292 makeI64Attrs(convolutionDims->batch));
293 results.setParams(cast<OpResult>(getOutputImage()),
294 makeI64Attrs(convolutionDims->outputImage));
295 results.setParams(cast<OpResult>(getOutputChannel()),
296 makeI64Attrs(convolutionDims->outputChannel));
297 results.setParams(cast<OpResult>(getFilterLoop()),
298 makeI64Attrs(convolutionDims->filterLoop));
299 results.setParams(cast<OpResult>(getInputChannel()),
300 makeI64Attrs(convolutionDims->inputChannel));
301 results.setParams(cast<OpResult>(getDepth()),
302 makeI64Attrs(convolutionDims->depth));
303
304 auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
305 return llvm::to_vector(
306 llvm::map_range(values, [&](int64_t value) -> Attribute {
307 return builder.getI64IntegerAttr(value);
308 }));
309 };
310 results.setParams(cast<OpResult>(getStrides()),
311 makeI64AttrsFromI64(convolutionDims->strides));
312 results.setParams(cast<OpResult>(getDilations()),
313 makeI64AttrsFromI64(convolutionDims->dilations));
314 return DiagnosedSilenceableFailure::success();
315}
316
317//===----------------------------------------------------------------------===//
318// Utilities for structured match predicates.
319//===----------------------------------------------------------------------===//
320
321/// Checks if all values from `list` are also contained in `reference`. Returns
322/// a silenceable error with the given message at the given location when it is
323/// not the case. The error message must contain the "{0}" placeholder that
324/// will be substituted with the value from `list` that is not contained in
325/// `reference`.
326static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
327 ArrayRef<int64_t> list,
328 Location loc,
329 const char *message) {
330 for (int64_t value : list) {
331 if (llvm::any_of(Range&: reference, P: [&](unsigned ref) {
332 return static_cast<int64_t>(ref) == value;
333 })) {
334 continue;
335 }
336 return emitSilenceableFailure(loc) << llvm::formatv(Fmt: message, Vals&: value);
337 }
338 return DiagnosedSilenceableFailure::success();
339}
340
341//===----------------------------------------------------------------------===//
342// MatchStructuredDimOp
343//===----------------------------------------------------------------------===//
344
345DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
346 Operation *current, transform::TransformResults &results,
347 transform::TransformState &state) {
348 auto linalgOp = cast<linalg::LinalgOp>(current);
349 SmallVector<int64_t> dimensions;
350 DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions);
351 if (!diag.succeeded())
352 return diag;
353
354 // If asked to check for the kind of dimension, perform the check.
355 if (getParallel() || getReduction()) {
356 SmallVector<unsigned> reference;
357 if (getParallel())
358 linalgOp.getParallelDims(reference);
359 else if (getReduction())
360 linalgOp.getReductionDims(reference);
361
362 DiagnosedSilenceableFailure diag =
363 containsAll(reference, dimensions, getLoc(),
364 getParallel() ? "expects dimension #{0} to be parallel"
365 : "expects dimension #{0} to be reduction");
366 if (!diag.succeeded())
367 return diag;
368 }
369
370 // If not capturing, we are done here.
371 if (!getResult())
372 return diag;
373
374 SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges();
375 Builder builder(current);
376 SmallVector<Attribute> captured = llvm::to_vector(
377 llvm::map_range(dimensions, [&](int64_t dim) -> Attribute {
378 return builder.getI64IntegerAttr(ranges[dim]);
379 }));
380 results.setParams(cast<OpResult>(getResult()), captured);
381 return DiagnosedSilenceableFailure::success();
382}
383
384DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
385 linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
386 DiagnosedSilenceableFailure diag =
387 expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
388 getRawDimList(), op.getNumLoops(), dims);
389 if (diag.isSilenceableFailure()) {
390 diag.attachNote(op->getLoc())
391 << "while considering dimensions of this payload operation";
392 }
393 return diag;
394}
395
396LogicalResult transform::MatchStructuredDimOp::verify() {
397 if (getParallel() && getReduction()) {
398 return emitOpError() << "cannot request the same dimension to be both "
399 "parallel and reduction";
400 }
401 return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
402 getIsInverted(), getIsAll());
403}
404
405//===----------------------------------------------------------------------===//
406// MatchStructuredElementalBitwidthOp
407//===----------------------------------------------------------------------===//
408
409DiagnosedSilenceableFailure
410transform::MatchStructuredElementalBitwidthOp::matchValue(
411 Value current, transform::TransformResults &results,
412 transform::TransformState &state) {
413 auto setupResult = [&](int64_t bitwidth) {
414 Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth);
415 results.setParams(cast<OpResult>(getResult()), {attr});
416 return DiagnosedSilenceableFailure::success();
417 };
418
419 Type type = current.getType();
420 if (type.isIntOrFloat())
421 return setupResult(type.getIntOrFloatBitWidth());
422
423 if (auto shapedType = dyn_cast<ShapedType>(type)) {
424 if (shapedType.getElementType().isIntOrFloat())
425 return setupResult(shapedType.getElementTypeBitWidth());
426 }
427 return emitSilenceableError()
428 << "unsupported type for bitwidth extraction: " << type;
429}
430
431//===----------------------------------------------------------------------===//
432// MatchStructuredInputOp
433//===----------------------------------------------------------------------===//
434
435DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
436 Operation *current, transform::TransformResults &results,
437 transform::TransformState &state) {
438 auto linalgOp = cast<linalg::LinalgOp>(current);
439 SmallVector<int64_t> positions;
440 DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
441 if (!diag.succeeded())
442 return diag;
443
444 SmallVector<MappedValue> operandMapping;
445 operandMapping.reserve(positions.size());
446 for (int64_t position : positions) {
447 AffineMap indexingMap =
448 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
449 if (getPermutation() && !indexingMap.isPermutation()) {
450 return emitSilenceableError() << "the indexing map for input #"
451 << position << " is not a permutation";
452 }
453 if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
454 return emitSilenceableError()
455 << "the indexing map for input #" << position
456 << " is not a projected permutation";
457 }
458
459 // If capture not requested, skip it.
460 if (!getResult())
461 continue;
462
463 if (isa<AffineMapParamType>(getResult().getType())) {
464 operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
465 continue;
466 }
467
468 Value operand = linalgOp.getDpsInputOperand(position)->get();
469 if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
470 operandMapping.emplace_back(operand);
471 continue;
472 }
473
474 Operation *operandProducer = operand.getDefiningOp();
475 if (!operandProducer) {
476 return emitSilenceableError()
477 << "input #" << position << " is not produced by an operation";
478 }
479 operandMapping.emplace_back(operandProducer);
480 }
481 if (getResult())
482 results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
483 return DiagnosedSilenceableFailure::success();
484}
485
486DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
487 linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
488 DiagnosedSilenceableFailure diag = expandTargetSpecification(
489 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
490 op.getNumDpsInputs(), positions);
491 if (diag.isSilenceableFailure()) {
492 diag.attachNote(op->getLoc())
493 << "while considering DPS inputs of this payload operation";
494 }
495 return diag;
496}
497
498/// Verifies a matcher op for structured input or output, specifically the
499/// attributes specifying the operand positions.
500template <typename OpTy>
501LogicalResult verifyStructuredOperandOp(OpTy op) {
502 if (op.getPermutation() && op.getProjectedPermutation()) {
503 return op.emitOpError()
504 << op.getPermutationAttrName() << " and "
505 << op.getProjectedPermutationAttrName() << " are mutually exclusive";
506 }
507 if (op.getRawPositionList().size() > 1 && op.getResult()) {
508 return op.emitOpError()
509 << "cannot bind multiple inputs/inits to the same value";
510 }
511
512 return success();
513}
514
515LogicalResult transform::MatchStructuredInputOp::verify() {
516 if (failed(verifyStructuredOperandOp(*this)))
517 return failure();
518 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
519 getIsInverted(), getIsAll());
520}
521
522//===----------------------------------------------------------------------===//
523// MatchStructuredInitOp
524//===----------------------------------------------------------------------===//
525
526DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
527 Operation *current, transform::TransformResults &results,
528 transform::TransformState &state) {
529 auto linalgOp = cast<linalg::LinalgOp>(current);
530 SmallVector<int64_t> positions;
531 DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
532 if (!diag.succeeded())
533 return diag;
534
535 SmallVector<MappedValue> operandMapping;
536 operandMapping.reserve(positions.size());
537 for (int64_t position : positions) {
538 AffineMap indexingMap =
539 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
540 if (getPermutation() && !indexingMap.isPermutation()) {
541 return emitSilenceableError() << "the indexing map for output(init) #"
542 << position << " is not a permutation";
543 }
544 if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
545 return emitSilenceableError() << "the indexing map for output(init) #"
546 << position << " is not a permutation";
547 }
548
549 // If capture not requested, skip it.
550 if (!getResult())
551 continue;
552
553 if (isa<AffineMapParamType>(getResult().getType())) {
554 operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
555 continue;
556 }
557
558 Value operand = linalgOp.getDpsInitOperand(position)->get();
559 if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
560 operandMapping.emplace_back(operand);
561 continue;
562 }
563
564 Operation *operandProducer = operand.getDefiningOp();
565 if (!operandProducer) {
566 return emitSilenceableError() << "output(init) #" << position
567 << " is not produced by an operation";
568 }
569 operandMapping.emplace_back(operandProducer);
570 }
571 if (getResult())
572 results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
573 return DiagnosedSilenceableFailure::success();
574}
575
576DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
577 linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
578 DiagnosedSilenceableFailure diag = expandTargetSpecification(
579 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
580 op.getNumDpsInits(), positions);
581 if (diag.isSilenceableFailure()) {
582 diag.attachNote(op->getLoc())
583 << "while considering DPS inits (outputs) of this payload operation";
584 }
585 return diag;
586}
587
588LogicalResult transform::MatchStructuredInitOp::verify() {
589 if (failed(verifyStructuredOperandOp(*this)))
590 return failure();
591 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
592 getIsInverted(), getIsAll());
593}
594
595//===----------------------------------------------------------------------===//
596// MatchStructuredNumInputsOp
597//===----------------------------------------------------------------------===//
598
599DiagnosedSilenceableFailure
600transform::MatchStructuredNumInputsOp::matchOperation(
601 Operation *current, transform::TransformResults &results,
602 transform::TransformState &state) {
603 auto linalgOp = cast<linalg::LinalgOp>(current);
604 Attribute attr =
605 Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs());
606 results.setParams(cast<OpResult>(getResult()), {attr});
607 return DiagnosedSilenceableFailure::success();
608}
609
610//===----------------------------------------------------------------------===//
611// MatchStructuredNumInitsOp
612//===----------------------------------------------------------------------===//
613
614DiagnosedSilenceableFailure
615transform::MatchStructuredNumInitsOp::matchOperation(
616 Operation *current, transform::TransformResults &results,
617 transform::TransformState &state) {
618 auto linalgOp = cast<linalg::LinalgOp>(current);
619 Attribute attr =
620 Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits());
621 results.setParams(cast<OpResult>(getResult()), {attr});
622 return DiagnosedSilenceableFailure::success();
623}
624
625//===----------------------------------------------------------------------===//
626// MatchStructuredRankOp
627//===----------------------------------------------------------------------===//
628
629DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
630 Operation *current, transform::TransformResults &results,
631 transform::TransformState &state) {
632 auto linalgOp = cast<linalg::LinalgOp>(current);
633 int64_t numLoops = linalgOp.getNumLoops();
634 Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops);
635 results.setParams(cast<OpResult>(getRank()), {attr});
636 return DiagnosedSilenceableFailure::success();
637}
638
639//===----------------------------------------------------------------------===//
640// MatchStructuredResultOp
641//===----------------------------------------------------------------------===//
642
643DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
644 Operation *op, transform::TransformResults &results,
645 transform::TransformState &state) {
646 auto linalgOp = cast<linalg::LinalgOp>(op);
647 int64_t position;
648 DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position);
649 if (!diag.succeeded())
650 return diag;
651
652 Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
653 if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
654 results.setValues(cast<OpResult>(getResult()), {result});
655 return DiagnosedSilenceableFailure::success();
656 }
657
658 if (result.getUsers().empty()) {
659 return emitSilenceableError()
660 << "no users of the result #" << getPosition();
661 }
662 Operation *firstUser = *result.getUsers().begin();
663 if (getAny()) {
664 results.set(cast<OpResult>(getResult()), {firstUser});
665 return DiagnosedSilenceableFailure::success();
666 }
667 if (getSingle()) {
668 if (!llvm::hasSingleElement(result.getUsers())) {
669 return emitSilenceableError()
670 << "more than one result user with single user requested";
671 }
672 results.set(cast<OpResult>(getResult()), {firstUser});
673 return DiagnosedSilenceableFailure::success();
674 }
675
676 return emitDefiniteFailure() << "unknown sub-predicate";
677}
678
679DiagnosedSilenceableFailure
680transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
681 int64_t &position) {
682 auto rawPosition = static_cast<int64_t>(getPosition());
683 position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
684 if (position >= op.getNumDpsInits() || position < 0) {
685 return emitSilenceableError()
686 << "position " << rawPosition
687 << " overflows the number of results(ints) of the payload operation";
688 }
689 return DiagnosedSilenceableFailure::success();
690}
691
692LogicalResult transform::MatchStructuredResultOp::verify() {
693 if ((getAny() || getSingle()) ^
694 isa<TransformHandleTypeInterface>(getResult().getType())) {
695 return emitOpError() << "expects either the any/single keyword or the type "
696 "value handle result type";
697 }
698 if (getAny() && getSingle()) {
699 return emitOpError() << "'any' and 'single' are mutually exclusive";
700 }
701 return success();
702}
703
704//===----------------------------------------------------------------------===//
705// MatchStructuredYieldOp
706//===----------------------------------------------------------------------===//
707
708void transform::MatchStructuredYieldOp::getEffects(
709 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
710 onlyReadsHandle(getHandles(), effects);
711 onlyReadsPayload(effects);
712}
713
714void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
715 OperationState &state) {
716 build(builder, state, ValueRange());
717}
718
719#define GET_OP_CLASSES
720#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
721

source code of mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp