1//===- LinalgTransformOps.cpp - Implementation of Linalg match ops --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
10#include "mlir/Analysis/SliceAnalysis.h"
11#include "mlir/Dialect/Linalg/IR/Linalg.h"
12#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
13#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
14#include "mlir/Dialect/Linalg/Utils/Utils.h"
15#include "mlir/Dialect/Transform/IR/TransformTypes.h"
16#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
17#include "mlir/IR/BuiltinAttributes.h"
18#include "mlir/Interfaces/FunctionImplementation.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "llvm/Support/InterleavedRange.h"
22
23using namespace mlir;
24
25#define DEBUG_TYPE "linalg-transforms"
26#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27
28//===----------------------------------------------------------------------===//
29// StructuredMatchOp
30//===----------------------------------------------------------------------===//
31
32DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
33 Operation *current, transform::TransformResults &results,
34 transform::TransformState &state) {
35 // First, check if the payload operation is a structured Linalg operation.
36 if (!isa<linalg::LinalgOp>(current)) {
37 if (getFailurePropagationMode().value_or(
38 FailurePropagationMode::Propagate) ==
39 FailurePropagationMode::Propagate) {
40 return emitSilenceableError() << "expected a Linalg op";
41 }
42 // If errors are suppressed, succeed and set all results to empty lists.
43 LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op");
44 results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
45 return DiagnosedSilenceableFailure::success();
46 }
47
48 // Bind `current` to the block argument.
49 auto scope = state.make_region_scope(getBodyRegion());
50 if (failed(state.mapBlockArgument(getBody()->getArgument(0),
51 MappedValue(current)))) {
52 return DiagnosedSilenceableFailure::definiteFailure();
53 }
54
55 for (Operation &nested : getBody()->without_terminator()) {
56 DiagnosedSilenceableFailure diag =
57 state.applyTransform(cast<TransformOpInterface>(nested));
58 if (diag.isDefiniteFailure())
59 return diag;
60 if (diag.succeeded())
61 continue;
62
63 // If propagating errors, do this immediately.
64 assert(diag.isSilenceableFailure());
65 if (getFailurePropagationMode().value_or(
66 FailurePropagationMode::Propagate) ==
67 FailurePropagationMode::Propagate) {
68 return diag;
69 }
70
71 // If suppressing errors, print the message into the debug stream before
72 // silencing it. Then set all results value that are already known.
73 // Results come from the terminator operands, which may be defined in the
74 // (single) block of this operation or above it. When they are defined
75 // above, they are known to be mapped at this point per SSA dominance.
76 // When they are defined in this block, we additionally check if we have
77 // already applied the operation that defines them. If not, the
78 // corresponding results will be set to empty lists.
79 LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
80 << "\n");
81 (void)diag.silence();
82 SmallVector<OpOperand *> undefinedOperands;
83 for (OpOperand &terminatorOperand :
84 getBody()->getTerminator()->getOpOperands()) {
85 Operation *definingOp = terminatorOperand.get().getDefiningOp();
86 if (!definingOp)
87 continue;
88 if (definingOp->getBlock() != getBody())
89 continue;
90 if (definingOp->isBeforeInBlock(&nested))
91 continue;
92
93 undefinedOperands.push_back(&terminatorOperand);
94 }
95
96 SmallVector<SmallVector<transform::MappedValue>> mappings;
97 auto filtered = llvm::make_filter_range(
98 getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) {
99 return !llvm::is_contained(undefinedOperands, &opOperand);
100 });
101 SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range(
102 filtered, [](OpOperand &opOperand) { return opOperand.get(); }));
103 detail::prepareValueMappings(mappings, definedOperands, state);
104 for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
105 results.setMappedValues(getResults()[operand.getOperandNumber()],
106 mapping);
107 }
108 results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
109 return DiagnosedSilenceableFailure::success();
110 }
111
112 // Set the results.
113 detail::forwardTerminatorOperands(getBody(), state, results);
114 return DiagnosedSilenceableFailure::success();
115}
116
117void transform::MatchStructuredOp::getEffects(
118 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
119 onlyReadsHandle(getCurrentMutable(), effects);
120 onlyReadsPayload(effects);
121 producesHandle(getOperation()->getOpResults(), effects);
122}
123
124LogicalResult transform::MatchStructuredOp::verify() {
125 if (getBody()->getNumArguments() != 1)
126 return emitOpError() << "expected one body argument";
127 if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) {
128 return emitOpError() << "expected body argument to implement "
129 "TransformHandleTypeInterface";
130 }
131 for (Operation &nested : getBody()->without_terminator()) {
132 if (isa<MatchOpInterface>(nested))
133 continue;
134 InFlightDiagnostic diag =
135 emitOpError()
136 << "expects nested operations to implement MatchOpInterface";
137 diag.attachNote(nested.getLoc()) << "offending operation";
138 return diag;
139 }
140 return success();
141}
142
143//===----------------------------------------------------------------------===//
144// StructuredOpPredicateOpTrait
145//===----------------------------------------------------------------------===//
146
147LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait(
148 Operation *op, Value structuredOpHandle) {
149 if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) {
150 return op->emitOpError() << "expects parent op to be '"
151 << MatchStructuredOp::getOperationName() << "'";
152 }
153
154 // Bail out here, let the verifier of the parent complain.
155 Operation *parent = op->getParentOp();
156 if (parent->getNumRegions() < 1 || parent->getRegion(index: 0).empty() ||
157 parent->getRegion(index: 0).front().getNumArguments() < 1)
158 return success();
159
160 if (structuredOpHandle != parent->getRegion(index: 0).front().getArgument(i: 0)) {
161 return op->emitOpError()
162 << "expected predicate to apply to the surrounding structured op";
163 }
164 return success();
165}
166
167//===----------------------------------------------------------------------===//
168// MatchStructuredBodyOp
169//===----------------------------------------------------------------------===//
170
171DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
172 Operation *current, transform::TransformResults &results,
173 transform::TransformState &state) {
174 auto linalgOp = cast<linalg::LinalgOp>(current);
175 if (std::optional<uint64_t> position = getReductionPosition()) {
176 SmallVector<Operation *> combinerOps;
177 if (!matchReduction(linalgOp.getRegionOutputArgs(), *position,
178 combinerOps)) {
179 return emitSilenceableError() << "could not match reduction";
180 }
181 if (combinerOps.size() != 1) {
182 return emitSilenceableError() << "reduction combiner is not a single op";
183 }
184 return DiagnosedSilenceableFailure::success();
185 }
186 if (getPassthrough()) {
187 Block &body = linalgOp->getRegion(0).front();
188 if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
189 return emitSilenceableError() << "not a passthrough";
190 }
191 return DiagnosedSilenceableFailure::success();
192 }
193 if (getElementwise()) {
194 if (!isElementwise(linalgOp))
195 return emitSilenceableError() << "not elementwise";
196 return DiagnosedSilenceableFailure::success();
197 }
198 if (std::optional<ArrayAttr> contractionOps = getContraction()) {
199 Block &body = linalgOp->getRegion(0).front();
200 std::string message;
201 llvm::raw_string_ostream os(message);
202 bool result = linalg::detail::isContractionBody(
203 body,
204 [&](Operation *elem, Operation *red) {
205 return elem->getName().getStringRef() ==
206 cast<StringAttr>((*contractionOps)[0]).getValue() &&
207 red->getName().getStringRef() ==
208 cast<StringAttr>((*contractionOps)[1]).getValue();
209 },
210 os);
211 if (result)
212 return DiagnosedSilenceableFailure::success();
213 return emitSilenceableError() << "contraction: " << message;
214 }
215 return emitDefiniteFailure() << "unknown body condition";
216}
217
218LogicalResult transform::MatchStructuredBodyOp::verify() {
219 int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
220 getElementwise() + getContraction().has_value();
221
222 if (numOptions > 1) {
223 StringAttr attributeNames[] = {
224 getReductionPositionAttrName(), getPassthroughAttrName(),
225 getElementwiseAttrName(), getContractionAttrName()};
226 return emitOpError() << "only one of {" << llvm::interleaved(attributeNames)
227 << "} is allowed";
228 }
229
230 if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
231 if (contractionAttr->size() != 2) {
232 return emitOpError() << "expects " << getContractionAttrName()
233 << " to contain two elements";
234 }
235 }
236 return success();
237}
238
239//===----------------------------------------------------------------------===//
240// MatchStructuredClassifyContractionDimsOp
241//===----------------------------------------------------------------------===//
242
243DiagnosedSilenceableFailure
244transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
245 Operation *current, transform::TransformResults &results,
246 transform::TransformState &state) {
247 FailureOr<linalg::ContractionDimensions> contractionDims =
248 linalg::inferContractionDims(cast<linalg::LinalgOp>(current));
249 if (failed(contractionDims))
250 return emitSilenceableError() << "could not infer contraction dimensions";
251
252 MLIRContext *context = current->getContext();
253 Builder builder(context);
254 auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
255 return llvm::to_vector(
256 llvm::map_range(values, [&](unsigned value) -> Attribute {
257 return builder.getI64IntegerAttr(value);
258 }));
259 };
260 results.setParams(cast<OpResult>(getBatch()),
261 makeI64Attrs(contractionDims->batch));
262 results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
263 results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
264 results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
265 return DiagnosedSilenceableFailure::success();
266}
267
268//===----------------------------------------------------------------------===//
269// MatchStructuredClassifyConvolutionDimsOp
270//===----------------------------------------------------------------------===//
271
272DiagnosedSilenceableFailure
273transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
274 Operation *current, transform::TransformResults &results,
275 transform::TransformState &state) {
276 FailureOr<linalg::ConvolutionDimensions> convolutionDims =
277 linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
278 if (failed(convolutionDims))
279 return emitSilenceableError() << "could not infer convolution dimensions";
280
281 MLIRContext *context = current->getContext();
282 Builder builder(context);
283 auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
284 return llvm::to_vector(
285 llvm::map_range(values, [&](unsigned value) -> Attribute {
286 return builder.getI64IntegerAttr(value);
287 }));
288 };
289 results.setParams(cast<OpResult>(getBatch()),
290 makeI64Attrs(convolutionDims->batch));
291 results.setParams(cast<OpResult>(getOutputImage()),
292 makeI64Attrs(convolutionDims->outputImage));
293 results.setParams(cast<OpResult>(getOutputChannel()),
294 makeI64Attrs(convolutionDims->outputChannel));
295 results.setParams(cast<OpResult>(getFilterLoop()),
296 makeI64Attrs(convolutionDims->filterLoop));
297 results.setParams(cast<OpResult>(getInputChannel()),
298 makeI64Attrs(convolutionDims->inputChannel));
299 results.setParams(cast<OpResult>(getDepth()),
300 makeI64Attrs(convolutionDims->depth));
301
302 auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
303 return llvm::to_vector(
304 llvm::map_range(values, [&](int64_t value) -> Attribute {
305 return builder.getI64IntegerAttr(value);
306 }));
307 };
308 results.setParams(cast<OpResult>(getStrides()),
309 makeI64AttrsFromI64(convolutionDims->strides));
310 results.setParams(cast<OpResult>(getDilations()),
311 makeI64AttrsFromI64(convolutionDims->dilations));
312 return DiagnosedSilenceableFailure::success();
313}
314
315//===----------------------------------------------------------------------===//
316// Utilities for structured match predicates.
317//===----------------------------------------------------------------------===//
318
319/// Checks if all values from `list` are also contained in `reference`. Returns
320/// a silenceable error with the given message at the given location when it is
321/// not the case. The error message must contain the "{0}" placeholder that
322/// will be substituted with the value from `list` that is not contained in
323/// `reference`.
324static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
325 ArrayRef<int64_t> list,
326 Location loc,
327 const char *message) {
328 for (int64_t value : list) {
329 if (llvm::any_of(Range&: reference, P: [&](unsigned ref) {
330 return static_cast<int64_t>(ref) == value;
331 })) {
332 continue;
333 }
334 return emitSilenceableFailure(loc) << llvm::formatv(Fmt: message, Vals&: value);
335 }
336 return DiagnosedSilenceableFailure::success();
337}
338
339//===----------------------------------------------------------------------===//
340// MatchStructuredDimOp
341//===----------------------------------------------------------------------===//
342
343DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
344 Operation *current, transform::TransformResults &results,
345 transform::TransformState &state) {
346 auto linalgOp = cast<linalg::LinalgOp>(current);
347 SmallVector<int64_t> dimensions;
348 DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions);
349 if (!diag.succeeded())
350 return diag;
351
352 // If asked to check for the kind of dimension, perform the check.
353 if (getParallel() || getReduction()) {
354 SmallVector<unsigned> reference;
355 if (getParallel())
356 linalgOp.getParallelDims(reference);
357 else if (getReduction())
358 linalgOp.getReductionDims(reference);
359
360 DiagnosedSilenceableFailure diag =
361 containsAll(reference, dimensions, getLoc(),
362 getParallel() ? "expects dimension #{0} to be parallel"
363 : "expects dimension #{0} to be reduction");
364 if (!diag.succeeded())
365 return diag;
366 }
367
368 // If not capturing, we are done here.
369 if (!getResult())
370 return diag;
371
372 SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges();
373 Builder builder(current);
374 SmallVector<Attribute> captured = llvm::to_vector(
375 llvm::map_range(dimensions, [&](int64_t dim) -> Attribute {
376 return builder.getI64IntegerAttr(ranges[dim]);
377 }));
378 results.setParams(cast<OpResult>(getResult()), captured);
379 return DiagnosedSilenceableFailure::success();
380}
381
382DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
383 linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
384 DiagnosedSilenceableFailure diag =
385 expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
386 getRawDimList(), op.getNumLoops(), dims);
387 if (diag.isSilenceableFailure()) {
388 diag.attachNote(op->getLoc())
389 << "while considering dimensions of this payload operation";
390 }
391 return diag;
392}
393
394LogicalResult transform::MatchStructuredDimOp::verify() {
395 if (getParallel() && getReduction()) {
396 return emitOpError() << "cannot request the same dimension to be both "
397 "parallel and reduction";
398 }
399 return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
400 getIsInverted(), getIsAll());
401}
402
403//===----------------------------------------------------------------------===//
404// MatchStructuredElementalBitwidthOp
405//===----------------------------------------------------------------------===//
406
407DiagnosedSilenceableFailure
408transform::MatchStructuredElementalBitwidthOp::matchValue(
409 Value current, transform::TransformResults &results,
410 transform::TransformState &state) {
411 auto setupResult = [&](int64_t bitwidth) {
412 Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth);
413 results.setParams(cast<OpResult>(getResult()), {attr});
414 return DiagnosedSilenceableFailure::success();
415 };
416
417 Type type = current.getType();
418 if (type.isIntOrFloat())
419 return setupResult(type.getIntOrFloatBitWidth());
420
421 if (auto shapedType = dyn_cast<ShapedType>(type)) {
422 if (shapedType.getElementType().isIntOrFloat())
423 return setupResult(shapedType.getElementTypeBitWidth());
424 }
425 return emitSilenceableError()
426 << "unsupported type for bitwidth extraction: " << type;
427}
428
429//===----------------------------------------------------------------------===//
430// MatchStructuredInputOp
431//===----------------------------------------------------------------------===//
432
433DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
434 Operation *current, transform::TransformResults &results,
435 transform::TransformState &state) {
436 auto linalgOp = cast<linalg::LinalgOp>(current);
437 SmallVector<int64_t> positions;
438 DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
439 if (!diag.succeeded())
440 return diag;
441
442 SmallVector<MappedValue> operandMapping;
443 operandMapping.reserve(positions.size());
444 for (int64_t position : positions) {
445 AffineMap indexingMap =
446 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
447 if (getPermutation() && !indexingMap.isPermutation()) {
448 return emitSilenceableError() << "the indexing map for input #"
449 << position << " is not a permutation";
450 }
451 if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
452 return emitSilenceableError()
453 << "the indexing map for input #" << position
454 << " is not a projected permutation";
455 }
456
457 // If capture not requested, skip it.
458 if (!getResult())
459 continue;
460
461 if (isa<AffineMapParamType>(getResult().getType())) {
462 operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
463 continue;
464 }
465
466 Value operand = linalgOp.getDpsInputOperand(position)->get();
467 if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
468 operandMapping.emplace_back(operand);
469 continue;
470 }
471
472 Operation *operandProducer = operand.getDefiningOp();
473 if (!operandProducer) {
474 return emitSilenceableError()
475 << "input #" << position << " is not produced by an operation";
476 }
477 operandMapping.emplace_back(operandProducer);
478 }
479 if (getResult())
480 results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
481 return DiagnosedSilenceableFailure::success();
482}
483
484DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
485 linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
486 DiagnosedSilenceableFailure diag = expandTargetSpecification(
487 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
488 op.getNumDpsInputs(), positions);
489 if (diag.isSilenceableFailure()) {
490 diag.attachNote(op->getLoc())
491 << "while considering DPS inputs of this payload operation";
492 }
493 return diag;
494}
495
496/// Verifies a matcher op for structured input or output, specifically the
497/// attributes specifying the operand positions.
498template <typename OpTy>
499LogicalResult verifyStructuredOperandOp(OpTy op) {
500 if (op.getPermutation() && op.getProjectedPermutation()) {
501 return op.emitOpError()
502 << op.getPermutationAttrName() << " and "
503 << op.getProjectedPermutationAttrName() << " are mutually exclusive";
504 }
505 if (op.getRawPositionList().size() > 1 && op.getResult()) {
506 return op.emitOpError()
507 << "cannot bind multiple inputs/inits to the same value";
508 }
509
510 return success();
511}
512
513LogicalResult transform::MatchStructuredInputOp::verify() {
514 if (failed(verifyStructuredOperandOp(*this)))
515 return failure();
516 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
517 getIsInverted(), getIsAll());
518}
519
520//===----------------------------------------------------------------------===//
521// MatchStructuredInitOp
522//===----------------------------------------------------------------------===//
523
524DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
525 Operation *current, transform::TransformResults &results,
526 transform::TransformState &state) {
527 auto linalgOp = cast<linalg::LinalgOp>(current);
528 SmallVector<int64_t> positions;
529 DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
530 if (!diag.succeeded())
531 return diag;
532
533 SmallVector<MappedValue> operandMapping;
534 operandMapping.reserve(positions.size());
535 for (int64_t position : positions) {
536 AffineMap indexingMap =
537 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
538 if (getPermutation() && !indexingMap.isPermutation()) {
539 return emitSilenceableError() << "the indexing map for output(init) #"
540 << position << " is not a permutation";
541 }
542 if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
543 return emitSilenceableError() << "the indexing map for output(init) #"
544 << position << " is not a permutation";
545 }
546
547 // If capture not requested, skip it.
548 if (!getResult())
549 continue;
550
551 if (isa<AffineMapParamType>(getResult().getType())) {
552 operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
553 continue;
554 }
555
556 Value operand = linalgOp.getDpsInitOperand(position)->get();
557 if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
558 operandMapping.emplace_back(operand);
559 continue;
560 }
561
562 Operation *operandProducer = operand.getDefiningOp();
563 if (!operandProducer) {
564 return emitSilenceableError() << "output(init) #" << position
565 << " is not produced by an operation";
566 }
567 operandMapping.emplace_back(operandProducer);
568 }
569 if (getResult())
570 results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
571 return DiagnosedSilenceableFailure::success();
572}
573
574DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
575 linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
576 DiagnosedSilenceableFailure diag = expandTargetSpecification(
577 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
578 op.getNumDpsInits(), positions);
579 if (diag.isSilenceableFailure()) {
580 diag.attachNote(op->getLoc())
581 << "while considering DPS inits (outputs) of this payload operation";
582 }
583 return diag;
584}
585
586LogicalResult transform::MatchStructuredInitOp::verify() {
587 if (failed(verifyStructuredOperandOp(*this)))
588 return failure();
589 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
590 getIsInverted(), getIsAll());
591}
592
593//===----------------------------------------------------------------------===//
594// MatchStructuredNumInputsOp
595//===----------------------------------------------------------------------===//
596
597DiagnosedSilenceableFailure
598transform::MatchStructuredNumInputsOp::matchOperation(
599 Operation *current, transform::TransformResults &results,
600 transform::TransformState &state) {
601 auto linalgOp = cast<linalg::LinalgOp>(current);
602 Attribute attr =
603 Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs());
604 results.setParams(cast<OpResult>(getResult()), {attr});
605 return DiagnosedSilenceableFailure::success();
606}
607
608//===----------------------------------------------------------------------===//
609// MatchStructuredNumInitsOp
610//===----------------------------------------------------------------------===//
611
612DiagnosedSilenceableFailure
613transform::MatchStructuredNumInitsOp::matchOperation(
614 Operation *current, transform::TransformResults &results,
615 transform::TransformState &state) {
616 auto linalgOp = cast<linalg::LinalgOp>(current);
617 Attribute attr =
618 Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits());
619 results.setParams(cast<OpResult>(getResult()), {attr});
620 return DiagnosedSilenceableFailure::success();
621}
622
623//===----------------------------------------------------------------------===//
624// MatchStructuredRankOp
625//===----------------------------------------------------------------------===//
626
627DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
628 Operation *current, transform::TransformResults &results,
629 transform::TransformState &state) {
630 auto linalgOp = cast<linalg::LinalgOp>(current);
631 int64_t numLoops = linalgOp.getNumLoops();
632 Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops);
633 results.setParams(cast<OpResult>(getRank()), {attr});
634 return DiagnosedSilenceableFailure::success();
635}
636
637//===----------------------------------------------------------------------===//
638// MatchStructuredResultOp
639//===----------------------------------------------------------------------===//
640
641DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
642 Operation *op, transform::TransformResults &results,
643 transform::TransformState &state) {
644 auto linalgOp = cast<linalg::LinalgOp>(op);
645 int64_t position;
646 DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position);
647 if (!diag.succeeded())
648 return diag;
649
650 Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
651 if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
652 results.setValues(cast<OpResult>(getResult()), {result});
653 return DiagnosedSilenceableFailure::success();
654 }
655
656 if (result.getUsers().empty()) {
657 return emitSilenceableError()
658 << "no users of the result #" << getPosition();
659 }
660 Operation *firstUser = *result.getUsers().begin();
661 if (getAny()) {
662 results.set(cast<OpResult>(getResult()), {firstUser});
663 return DiagnosedSilenceableFailure::success();
664 }
665 if (getSingle()) {
666 if (!llvm::hasSingleElement(result.getUsers())) {
667 return emitSilenceableError()
668 << "more than one result user with single user requested";
669 }
670 results.set(cast<OpResult>(getResult()), {firstUser});
671 return DiagnosedSilenceableFailure::success();
672 }
673
674 return emitDefiniteFailure() << "unknown sub-predicate";
675}
676
677DiagnosedSilenceableFailure
678transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
679 int64_t &position) {
680 auto rawPosition = static_cast<int64_t>(getPosition());
681 position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
682 if (position >= op.getNumDpsInits() || position < 0) {
683 return emitSilenceableError()
684 << "position " << rawPosition
685 << " overflows the number of results(ints) of the payload operation";
686 }
687 return DiagnosedSilenceableFailure::success();
688}
689
690LogicalResult transform::MatchStructuredResultOp::verify() {
691 if ((getAny() || getSingle()) ^
692 isa<TransformHandleTypeInterface>(getResult().getType())) {
693 return emitOpError() << "expects either the any/single keyword or the type "
694 "value handle result type";
695 }
696 if (getAny() && getSingle()) {
697 return emitOpError() << "'any' and 'single' are mutually exclusive";
698 }
699 return success();
700}
701
702//===----------------------------------------------------------------------===//
703// MatchStructuredYieldOp
704//===----------------------------------------------------------------------===//
705
706void transform::MatchStructuredYieldOp::getEffects(
707 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
708 onlyReadsHandle(getHandlesMutable(), effects);
709 onlyReadsPayload(effects);
710}
711
712void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
713 OperationState &state) {
714 build(builder, state, ValueRange());
715}
716
717#define GET_OP_CLASSES
718#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
719

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp