//===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h" |
10 | #include "mlir/Analysis/SliceAnalysis.h" |
11 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
12 | #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" |
13 | #include "mlir/Dialect/Linalg/TransformOps/Syntax.h" |
14 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
15 | #include "mlir/Dialect/Transform/IR/TransformTypes.h" |
16 | #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" |
17 | #include "mlir/IR/BuiltinAttributes.h" |
18 | #include "mlir/Interfaces/FunctionImplementation.h" |
19 | #include "llvm/Support/Debug.h" |
20 | #include "llvm/Support/FormatVariadic.h" |
21 | #include "llvm/Support/InterleavedRange.h" |
22 | |
23 | using namespace mlir; |
24 | |
25 | #define DEBUG_TYPE "linalg-transforms" |
26 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
27 | |
28 | //===----------------------------------------------------------------------===// |
29 | // StructuredMatchOp |
30 | //===----------------------------------------------------------------------===// |
31 | |
32 | DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation( |
33 | Operation *current, transform::TransformResults &results, |
34 | transform::TransformState &state) { |
35 | // First, check if the payload operation is a structured Linalg operation. |
36 | if (!isa<linalg::LinalgOp>(current)) { |
37 | if (getFailurePropagationMode().value_or( |
38 | FailurePropagationMode::Propagate) == |
39 | FailurePropagationMode::Propagate) { |
40 | return emitSilenceableError() << "expected a Linalg op" ; |
41 | } |
42 | // If errors are suppressed, succeed and set all results to empty lists. |
43 | LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op" ); |
44 | results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation())); |
45 | return DiagnosedSilenceableFailure::success(); |
46 | } |
47 | |
48 | // Bind `current` to the block argument. |
49 | auto scope = state.make_region_scope(getBodyRegion()); |
50 | if (failed(state.mapBlockArgument(getBody()->getArgument(0), |
51 | MappedValue(current)))) { |
52 | return DiagnosedSilenceableFailure::definiteFailure(); |
53 | } |
54 | |
55 | for (Operation &nested : getBody()->without_terminator()) { |
56 | DiagnosedSilenceableFailure diag = |
57 | state.applyTransform(cast<TransformOpInterface>(nested)); |
58 | if (diag.isDefiniteFailure()) |
59 | return diag; |
60 | if (diag.succeeded()) |
61 | continue; |
62 | |
63 | // If propagating errors, do this immediately. |
64 | assert(diag.isSilenceableFailure()); |
65 | if (getFailurePropagationMode().value_or( |
66 | FailurePropagationMode::Propagate) == |
67 | FailurePropagationMode::Propagate) { |
68 | return diag; |
69 | } |
70 | |
71 | // If suppressing errors, print the message into the debug stream before |
72 | // silencing it. Then set all results value that are already known. |
73 | // Results come from the terminator operands, which may be defined in the |
74 | // (single) block of this operation or above it. When they are defined |
75 | // above, they are known to be mapped at this point per SSA dominance. |
76 | // When they are defined in this block, we additionally check if we have |
77 | // already applied the operation that defines them. If not, the |
78 | // corresponding results will be set to empty lists. |
79 | LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage() |
80 | << "\n" ); |
81 | (void)diag.silence(); |
82 | SmallVector<OpOperand *> undefinedOperands; |
83 | for (OpOperand &terminatorOperand : |
84 | getBody()->getTerminator()->getOpOperands()) { |
85 | Operation *definingOp = terminatorOperand.get().getDefiningOp(); |
86 | if (!definingOp) |
87 | continue; |
88 | if (definingOp->getBlock() != getBody()) |
89 | continue; |
90 | if (definingOp->isBeforeInBlock(&nested)) |
91 | continue; |
92 | |
93 | undefinedOperands.push_back(&terminatorOperand); |
94 | } |
95 | |
96 | SmallVector<SmallVector<transform::MappedValue>> mappings; |
97 | auto filtered = llvm::make_filter_range( |
98 | getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) { |
99 | return !llvm::is_contained(undefinedOperands, &opOperand); |
100 | }); |
101 | SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range( |
102 | filtered, [](OpOperand &opOperand) { return opOperand.get(); })); |
103 | detail::prepareValueMappings(mappings, definedOperands, state); |
104 | for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) { |
105 | results.setMappedValues(getResults()[operand.getOperandNumber()], |
106 | mapping); |
107 | } |
108 | results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation())); |
109 | return DiagnosedSilenceableFailure::success(); |
110 | } |
111 | |
112 | // Set the results. |
113 | detail::forwardTerminatorOperands(getBody(), state, results); |
114 | return DiagnosedSilenceableFailure::success(); |
115 | } |
116 | |
117 | void transform::MatchStructuredOp::getEffects( |
118 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
119 | onlyReadsHandle(getCurrentMutable(), effects); |
120 | onlyReadsPayload(effects); |
121 | producesHandle(getOperation()->getOpResults(), effects); |
122 | } |
123 | |
124 | LogicalResult transform::MatchStructuredOp::verify() { |
125 | if (getBody()->getNumArguments() != 1) |
126 | return emitOpError() << "expected one body argument" ; |
127 | if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) { |
128 | return emitOpError() << "expected body argument to implement " |
129 | "TransformHandleTypeInterface" ; |
130 | } |
131 | for (Operation &nested : getBody()->without_terminator()) { |
132 | if (isa<MatchOpInterface>(nested)) |
133 | continue; |
134 | InFlightDiagnostic diag = |
135 | emitOpError() |
136 | << "expects nested operations to implement MatchOpInterface" ; |
137 | diag.attachNote(nested.getLoc()) << "offending operation" ; |
138 | return diag; |
139 | } |
140 | return success(); |
141 | } |
142 | |
143 | //===----------------------------------------------------------------------===// |
144 | // StructuredOpPredicateOpTrait |
145 | //===----------------------------------------------------------------------===// |
146 | |
147 | LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait( |
148 | Operation *op, Value structuredOpHandle) { |
149 | if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) { |
150 | return op->emitOpError() << "expects parent op to be '" |
151 | << MatchStructuredOp::getOperationName() << "'" ; |
152 | } |
153 | |
154 | // Bail out here, let the verifier of the parent complain. |
155 | Operation *parent = op->getParentOp(); |
156 | if (parent->getNumRegions() < 1 || parent->getRegion(index: 0).empty() || |
157 | parent->getRegion(index: 0).front().getNumArguments() < 1) |
158 | return success(); |
159 | |
160 | if (structuredOpHandle != parent->getRegion(index: 0).front().getArgument(i: 0)) { |
161 | return op->emitOpError() |
162 | << "expected predicate to apply to the surrounding structured op" ; |
163 | } |
164 | return success(); |
165 | } |
166 | |
167 | //===----------------------------------------------------------------------===// |
168 | // MatchStructuredBodyOp |
169 | //===----------------------------------------------------------------------===// |
170 | |
171 | DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation( |
172 | Operation *current, transform::TransformResults &results, |
173 | transform::TransformState &state) { |
174 | auto linalgOp = cast<linalg::LinalgOp>(current); |
175 | if (std::optional<uint64_t> position = getReductionPosition()) { |
176 | SmallVector<Operation *> combinerOps; |
177 | if (!matchReduction(linalgOp.getRegionOutputArgs(), *position, |
178 | combinerOps)) { |
179 | return emitSilenceableError() << "could not match reduction" ; |
180 | } |
181 | if (combinerOps.size() != 1) { |
182 | return emitSilenceableError() << "reduction combiner is not a single op" ; |
183 | } |
184 | return DiagnosedSilenceableFailure::success(); |
185 | } |
186 | if (getPassthrough()) { |
187 | Block &body = linalgOp->getRegion(0).front(); |
188 | if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) { |
189 | return emitSilenceableError() << "not a passthrough" ; |
190 | } |
191 | return DiagnosedSilenceableFailure::success(); |
192 | } |
193 | if (getElementwise()) { |
194 | if (!isElementwise(linalgOp)) |
195 | return emitSilenceableError() << "not elementwise" ; |
196 | return DiagnosedSilenceableFailure::success(); |
197 | } |
198 | if (std::optional<ArrayAttr> contractionOps = getContraction()) { |
199 | Block &body = linalgOp->getRegion(0).front(); |
200 | std::string message; |
201 | llvm::raw_string_ostream os(message); |
202 | bool result = linalg::detail::isContractionBody( |
203 | body, |
204 | [&](Operation *elem, Operation *red) { |
205 | return elem->getName().getStringRef() == |
206 | cast<StringAttr>((*contractionOps)[0]).getValue() && |
207 | red->getName().getStringRef() == |
208 | cast<StringAttr>((*contractionOps)[1]).getValue(); |
209 | }, |
210 | os); |
211 | if (result) |
212 | return DiagnosedSilenceableFailure::success(); |
213 | return emitSilenceableError() << "contraction: " << message; |
214 | } |
215 | return emitDefiniteFailure() << "unknown body condition" ; |
216 | } |
217 | |
218 | LogicalResult transform::MatchStructuredBodyOp::verify() { |
219 | int64_t numOptions = getReductionPosition().has_value() + getPassthrough() + |
220 | getElementwise() + getContraction().has_value(); |
221 | |
222 | if (numOptions > 1) { |
223 | StringAttr attributeNames[] = { |
224 | getReductionPositionAttrName(), getPassthroughAttrName(), |
225 | getElementwiseAttrName(), getContractionAttrName()}; |
226 | return emitOpError() << "only one of {" << llvm::interleaved(attributeNames) |
227 | << "} is allowed" ; |
228 | } |
229 | |
230 | if (std::optional<ArrayAttr> contractionAttr = getContraction()) { |
231 | if (contractionAttr->size() != 2) { |
232 | return emitOpError() << "expects " << getContractionAttrName() |
233 | << " to contain two elements" ; |
234 | } |
235 | } |
236 | return success(); |
237 | } |
238 | |
239 | //===----------------------------------------------------------------------===// |
240 | // MatchStructuredClassifyContractionDimsOp |
241 | //===----------------------------------------------------------------------===// |
242 | |
243 | DiagnosedSilenceableFailure |
244 | transform::MatchStructuredClassifyContractionDimsOp::matchOperation( |
245 | Operation *current, transform::TransformResults &results, |
246 | transform::TransformState &state) { |
247 | FailureOr<linalg::ContractionDimensions> contractionDims = |
248 | linalg::inferContractionDims(cast<linalg::LinalgOp>(current)); |
249 | if (failed(contractionDims)) |
250 | return emitSilenceableError() << "could not infer contraction dimensions" ; |
251 | |
252 | MLIRContext *context = current->getContext(); |
253 | Builder builder(context); |
254 | auto makeI64Attrs = [&](ArrayRef<unsigned> values) { |
255 | return llvm::to_vector( |
256 | llvm::map_range(values, [&](unsigned value) -> Attribute { |
257 | return builder.getI64IntegerAttr(value); |
258 | })); |
259 | }; |
260 | results.setParams(cast<OpResult>(getBatch()), |
261 | makeI64Attrs(contractionDims->batch)); |
262 | results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m)); |
263 | results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n)); |
264 | results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k)); |
265 | return DiagnosedSilenceableFailure::success(); |
266 | } |
267 | |
268 | //===----------------------------------------------------------------------===// |
269 | // MatchStructuredClassifyConvolutionDimsOp |
270 | //===----------------------------------------------------------------------===// |
271 | |
272 | DiagnosedSilenceableFailure |
273 | transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation( |
274 | Operation *current, transform::TransformResults &results, |
275 | transform::TransformState &state) { |
276 | FailureOr<linalg::ConvolutionDimensions> convolutionDims = |
277 | linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current)); |
278 | if (failed(convolutionDims)) |
279 | return emitSilenceableError() << "could not infer convolution dimensions" ; |
280 | |
281 | MLIRContext *context = current->getContext(); |
282 | Builder builder(context); |
283 | auto makeI64Attrs = [&](ArrayRef<unsigned> values) { |
284 | return llvm::to_vector( |
285 | llvm::map_range(values, [&](unsigned value) -> Attribute { |
286 | return builder.getI64IntegerAttr(value); |
287 | })); |
288 | }; |
289 | results.setParams(cast<OpResult>(getBatch()), |
290 | makeI64Attrs(convolutionDims->batch)); |
291 | results.setParams(cast<OpResult>(getOutputImage()), |
292 | makeI64Attrs(convolutionDims->outputImage)); |
293 | results.setParams(cast<OpResult>(getOutputChannel()), |
294 | makeI64Attrs(convolutionDims->outputChannel)); |
295 | results.setParams(cast<OpResult>(getFilterLoop()), |
296 | makeI64Attrs(convolutionDims->filterLoop)); |
297 | results.setParams(cast<OpResult>(getInputChannel()), |
298 | makeI64Attrs(convolutionDims->inputChannel)); |
299 | results.setParams(cast<OpResult>(getDepth()), |
300 | makeI64Attrs(convolutionDims->depth)); |
301 | |
302 | auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) { |
303 | return llvm::to_vector( |
304 | llvm::map_range(values, [&](int64_t value) -> Attribute { |
305 | return builder.getI64IntegerAttr(value); |
306 | })); |
307 | }; |
308 | results.setParams(cast<OpResult>(getStrides()), |
309 | makeI64AttrsFromI64(convolutionDims->strides)); |
310 | results.setParams(cast<OpResult>(getDilations()), |
311 | makeI64AttrsFromI64(convolutionDims->dilations)); |
312 | return DiagnosedSilenceableFailure::success(); |
313 | } |
314 | |
315 | //===----------------------------------------------------------------------===// |
316 | // Utilities for structured match predicates. |
317 | //===----------------------------------------------------------------------===// |
318 | |
319 | /// Checks if all values from `list` are also contained in `reference`. Returns |
320 | /// a silenceable error with the given message at the given location when it is |
321 | /// not the case. The error message must contain the "{0}" placeholder that |
322 | /// will be substituted with the value from `list` that is not contained in |
323 | /// `reference`. |
324 | static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference, |
325 | ArrayRef<int64_t> list, |
326 | Location loc, |
327 | const char *message) { |
328 | for (int64_t value : list) { |
329 | if (llvm::any_of(Range&: reference, P: [&](unsigned ref) { |
330 | return static_cast<int64_t>(ref) == value; |
331 | })) { |
332 | continue; |
333 | } |
334 | return emitSilenceableFailure(loc) << llvm::formatv(Fmt: message, Vals&: value); |
335 | } |
336 | return DiagnosedSilenceableFailure::success(); |
337 | } |
338 | |
339 | //===----------------------------------------------------------------------===// |
340 | // MatchStructuredDimOp |
341 | //===----------------------------------------------------------------------===// |
342 | |
343 | DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation( |
344 | Operation *current, transform::TransformResults &results, |
345 | transform::TransformState &state) { |
346 | auto linalgOp = cast<linalg::LinalgOp>(current); |
347 | SmallVector<int64_t> dimensions; |
348 | DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions); |
349 | if (!diag.succeeded()) |
350 | return diag; |
351 | |
352 | // If asked to check for the kind of dimension, perform the check. |
353 | if (getParallel() || getReduction()) { |
354 | SmallVector<unsigned> reference; |
355 | if (getParallel()) |
356 | linalgOp.getParallelDims(reference); |
357 | else if (getReduction()) |
358 | linalgOp.getReductionDims(reference); |
359 | |
360 | DiagnosedSilenceableFailure diag = |
361 | containsAll(reference, dimensions, getLoc(), |
362 | getParallel() ? "expects dimension #{0} to be parallel" |
363 | : "expects dimension #{0} to be reduction" ); |
364 | if (!diag.succeeded()) |
365 | return diag; |
366 | } |
367 | |
368 | // If not capturing, we are done here. |
369 | if (!getResult()) |
370 | return diag; |
371 | |
372 | SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges(); |
373 | Builder builder(current); |
374 | SmallVector<Attribute> captured = llvm::to_vector( |
375 | llvm::map_range(dimensions, [&](int64_t dim) -> Attribute { |
376 | return builder.getI64IntegerAttr(ranges[dim]); |
377 | })); |
378 | results.setParams(cast<OpResult>(getResult()), captured); |
379 | return DiagnosedSilenceableFailure::success(); |
380 | } |
381 | |
382 | DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor( |
383 | linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) { |
384 | DiagnosedSilenceableFailure diag = |
385 | expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(), |
386 | getRawDimList(), op.getNumLoops(), dims); |
387 | if (diag.isSilenceableFailure()) { |
388 | diag.attachNote(op->getLoc()) |
389 | << "while considering dimensions of this payload operation" ; |
390 | } |
391 | return diag; |
392 | } |
393 | |
394 | LogicalResult transform::MatchStructuredDimOp::verify() { |
395 | if (getParallel() && getReduction()) { |
396 | return emitOpError() << "cannot request the same dimension to be both " |
397 | "parallel and reduction" ; |
398 | } |
399 | return verifyTransformMatchDimsOp(getOperation(), getRawDimList(), |
400 | getIsInverted(), getIsAll()); |
401 | } |
402 | |
403 | //===----------------------------------------------------------------------===// |
404 | // MatchStructuredElementalBitwidthOp |
405 | //===----------------------------------------------------------------------===// |
406 | |
407 | DiagnosedSilenceableFailure |
408 | transform::MatchStructuredElementalBitwidthOp::matchValue( |
409 | Value current, transform::TransformResults &results, |
410 | transform::TransformState &state) { |
411 | auto setupResult = [&](int64_t bitwidth) { |
412 | Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth); |
413 | results.setParams(cast<OpResult>(getResult()), {attr}); |
414 | return DiagnosedSilenceableFailure::success(); |
415 | }; |
416 | |
417 | Type type = current.getType(); |
418 | if (type.isIntOrFloat()) |
419 | return setupResult(type.getIntOrFloatBitWidth()); |
420 | |
421 | if (auto shapedType = dyn_cast<ShapedType>(type)) { |
422 | if (shapedType.getElementType().isIntOrFloat()) |
423 | return setupResult(shapedType.getElementTypeBitWidth()); |
424 | } |
425 | return emitSilenceableError() |
426 | << "unsupported type for bitwidth extraction: " << type; |
427 | } |
428 | |
429 | //===----------------------------------------------------------------------===// |
430 | // MatchStructuredInputOp |
431 | //===----------------------------------------------------------------------===// |
432 | |
433 | DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation( |
434 | Operation *current, transform::TransformResults &results, |
435 | transform::TransformState &state) { |
436 | auto linalgOp = cast<linalg::LinalgOp>(current); |
437 | SmallVector<int64_t> positions; |
438 | DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions); |
439 | if (!diag.succeeded()) |
440 | return diag; |
441 | |
442 | SmallVector<MappedValue> operandMapping; |
443 | operandMapping.reserve(positions.size()); |
444 | for (int64_t position : positions) { |
445 | AffineMap indexingMap = |
446 | linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position)); |
447 | if (getPermutation() && !indexingMap.isPermutation()) { |
448 | return emitSilenceableError() << "the indexing map for input #" |
449 | << position << " is not a permutation" ; |
450 | } |
451 | if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) { |
452 | return emitSilenceableError() |
453 | << "the indexing map for input #" << position |
454 | << " is not a projected permutation" ; |
455 | } |
456 | |
457 | // If capture not requested, skip it. |
458 | if (!getResult()) |
459 | continue; |
460 | |
461 | if (isa<AffineMapParamType>(getResult().getType())) { |
462 | operandMapping.emplace_back(AffineMapAttr::get(indexingMap)); |
463 | continue; |
464 | } |
465 | |
466 | Value operand = linalgOp.getDpsInputOperand(position)->get(); |
467 | if (isa<TransformValueHandleTypeInterface>(getResult().getType())) { |
468 | operandMapping.emplace_back(operand); |
469 | continue; |
470 | } |
471 | |
472 | Operation *operandProducer = operand.getDefiningOp(); |
473 | if (!operandProducer) { |
474 | return emitSilenceableError() |
475 | << "input #" << position << " is not produced by an operation" ; |
476 | } |
477 | operandMapping.emplace_back(operandProducer); |
478 | } |
479 | if (getResult()) |
480 | results.setMappedValues(cast<OpResult>(getResult()), operandMapping); |
481 | return DiagnosedSilenceableFailure::success(); |
482 | } |
483 | |
484 | DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor( |
485 | linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) { |
486 | DiagnosedSilenceableFailure diag = expandTargetSpecification( |
487 | getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), |
488 | op.getNumDpsInputs(), positions); |
489 | if (diag.isSilenceableFailure()) { |
490 | diag.attachNote(op->getLoc()) |
491 | << "while considering DPS inputs of this payload operation" ; |
492 | } |
493 | return diag; |
494 | } |
495 | |
496 | /// Verifies a matcher op for structured input or output, specifically the |
497 | /// attributes specifying the operand positions. |
498 | template <typename OpTy> |
499 | LogicalResult verifyStructuredOperandOp(OpTy op) { |
500 | if (op.getPermutation() && op.getProjectedPermutation()) { |
501 | return op.emitOpError() |
502 | << op.getPermutationAttrName() << " and " |
503 | << op.getProjectedPermutationAttrName() << " are mutually exclusive" ; |
504 | } |
505 | if (op.getRawPositionList().size() > 1 && op.getResult()) { |
506 | return op.emitOpError() |
507 | << "cannot bind multiple inputs/inits to the same value" ; |
508 | } |
509 | |
510 | return success(); |
511 | } |
512 | |
513 | LogicalResult transform::MatchStructuredInputOp::verify() { |
514 | if (failed(verifyStructuredOperandOp(*this))) |
515 | return failure(); |
516 | return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), |
517 | getIsInverted(), getIsAll()); |
518 | } |
519 | |
520 | //===----------------------------------------------------------------------===// |
521 | // MatchStructuredInitOp |
522 | //===----------------------------------------------------------------------===// |
523 | |
524 | DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation( |
525 | Operation *current, transform::TransformResults &results, |
526 | transform::TransformState &state) { |
527 | auto linalgOp = cast<linalg::LinalgOp>(current); |
528 | SmallVector<int64_t> positions; |
529 | DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions); |
530 | if (!diag.succeeded()) |
531 | return diag; |
532 | |
533 | SmallVector<MappedValue> operandMapping; |
534 | operandMapping.reserve(positions.size()); |
535 | for (int64_t position : positions) { |
536 | AffineMap indexingMap = |
537 | linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position)); |
538 | if (getPermutation() && !indexingMap.isPermutation()) { |
539 | return emitSilenceableError() << "the indexing map for output(init) #" |
540 | << position << " is not a permutation" ; |
541 | } |
542 | if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) { |
543 | return emitSilenceableError() << "the indexing map for output(init) #" |
544 | << position << " is not a permutation" ; |
545 | } |
546 | |
547 | // If capture not requested, skip it. |
548 | if (!getResult()) |
549 | continue; |
550 | |
551 | if (isa<AffineMapParamType>(getResult().getType())) { |
552 | operandMapping.emplace_back(AffineMapAttr::get(indexingMap)); |
553 | continue; |
554 | } |
555 | |
556 | Value operand = linalgOp.getDpsInitOperand(position)->get(); |
557 | if (isa<TransformValueHandleTypeInterface>(getResult().getType())) { |
558 | operandMapping.emplace_back(operand); |
559 | continue; |
560 | } |
561 | |
562 | Operation *operandProducer = operand.getDefiningOp(); |
563 | if (!operandProducer) { |
564 | return emitSilenceableError() << "output(init) #" << position |
565 | << " is not produced by an operation" ; |
566 | } |
567 | operandMapping.emplace_back(operandProducer); |
568 | } |
569 | if (getResult()) |
570 | results.setMappedValues(cast<OpResult>(getResult()), operandMapping); |
571 | return DiagnosedSilenceableFailure::success(); |
572 | } |
573 | |
574 | DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor( |
575 | linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) { |
576 | DiagnosedSilenceableFailure diag = expandTargetSpecification( |
577 | getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), |
578 | op.getNumDpsInits(), positions); |
579 | if (diag.isSilenceableFailure()) { |
580 | diag.attachNote(op->getLoc()) |
581 | << "while considering DPS inits (outputs) of this payload operation" ; |
582 | } |
583 | return diag; |
584 | } |
585 | |
586 | LogicalResult transform::MatchStructuredInitOp::verify() { |
587 | if (failed(verifyStructuredOperandOp(*this))) |
588 | return failure(); |
589 | return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), |
590 | getIsInverted(), getIsAll()); |
591 | } |
592 | |
593 | //===----------------------------------------------------------------------===// |
594 | // MatchStructuredNumInputsOp |
595 | //===----------------------------------------------------------------------===// |
596 | |
597 | DiagnosedSilenceableFailure |
598 | transform::MatchStructuredNumInputsOp::matchOperation( |
599 | Operation *current, transform::TransformResults &results, |
600 | transform::TransformState &state) { |
601 | auto linalgOp = cast<linalg::LinalgOp>(current); |
602 | Attribute attr = |
603 | Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs()); |
604 | results.setParams(cast<OpResult>(getResult()), {attr}); |
605 | return DiagnosedSilenceableFailure::success(); |
606 | } |
607 | |
608 | //===----------------------------------------------------------------------===// |
609 | // MatchStructuredNumInitsOp |
610 | //===----------------------------------------------------------------------===// |
611 | |
612 | DiagnosedSilenceableFailure |
613 | transform::MatchStructuredNumInitsOp::matchOperation( |
614 | Operation *current, transform::TransformResults &results, |
615 | transform::TransformState &state) { |
616 | auto linalgOp = cast<linalg::LinalgOp>(current); |
617 | Attribute attr = |
618 | Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits()); |
619 | results.setParams(cast<OpResult>(getResult()), {attr}); |
620 | return DiagnosedSilenceableFailure::success(); |
621 | } |
622 | |
623 | //===----------------------------------------------------------------------===// |
624 | // MatchStructuredRankOp |
625 | //===----------------------------------------------------------------------===// |
626 | |
627 | DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation( |
628 | Operation *current, transform::TransformResults &results, |
629 | transform::TransformState &state) { |
630 | auto linalgOp = cast<linalg::LinalgOp>(current); |
631 | int64_t numLoops = linalgOp.getNumLoops(); |
632 | Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops); |
633 | results.setParams(cast<OpResult>(getRank()), {attr}); |
634 | return DiagnosedSilenceableFailure::success(); |
635 | } |
636 | |
637 | //===----------------------------------------------------------------------===// |
638 | // MatchStructuredResultOp |
639 | //===----------------------------------------------------------------------===// |
640 | |
641 | DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation( |
642 | Operation *op, transform::TransformResults &results, |
643 | transform::TransformState &state) { |
644 | auto linalgOp = cast<linalg::LinalgOp>(op); |
645 | int64_t position; |
646 | DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position); |
647 | if (!diag.succeeded()) |
648 | return diag; |
649 | |
650 | Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position)); |
651 | if (isa<TransformValueHandleTypeInterface>(getResult().getType())) { |
652 | results.setValues(cast<OpResult>(getResult()), {result}); |
653 | return DiagnosedSilenceableFailure::success(); |
654 | } |
655 | |
656 | if (result.getUsers().empty()) { |
657 | return emitSilenceableError() |
658 | << "no users of the result #" << getPosition(); |
659 | } |
660 | Operation *firstUser = *result.getUsers().begin(); |
661 | if (getAny()) { |
662 | results.set(cast<OpResult>(getResult()), {firstUser}); |
663 | return DiagnosedSilenceableFailure::success(); |
664 | } |
665 | if (getSingle()) { |
666 | if (!llvm::hasSingleElement(result.getUsers())) { |
667 | return emitSilenceableError() |
668 | << "more than one result user with single user requested" ; |
669 | } |
670 | results.set(cast<OpResult>(getResult()), {firstUser}); |
671 | return DiagnosedSilenceableFailure::success(); |
672 | } |
673 | |
674 | return emitDefiniteFailure() << "unknown sub-predicate" ; |
675 | } |
676 | |
677 | DiagnosedSilenceableFailure |
678 | transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op, |
679 | int64_t &position) { |
680 | auto rawPosition = static_cast<int64_t>(getPosition()); |
681 | position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition; |
682 | if (position >= op.getNumDpsInits() || position < 0) { |
683 | return emitSilenceableError() |
684 | << "position " << rawPosition |
685 | << " overflows the number of results(ints) of the payload operation" ; |
686 | } |
687 | return DiagnosedSilenceableFailure::success(); |
688 | } |
689 | |
690 | LogicalResult transform::MatchStructuredResultOp::verify() { |
691 | if ((getAny() || getSingle()) ^ |
692 | isa<TransformHandleTypeInterface>(getResult().getType())) { |
693 | return emitOpError() << "expects either the any/single keyword or the type " |
694 | "value handle result type" ; |
695 | } |
696 | if (getAny() && getSingle()) { |
697 | return emitOpError() << "'any' and 'single' are mutually exclusive" ; |
698 | } |
699 | return success(); |
700 | } |
701 | |
702 | //===----------------------------------------------------------------------===// |
703 | // MatchStructuredYieldOp |
704 | //===----------------------------------------------------------------------===// |
705 | |
706 | void transform::MatchStructuredYieldOp::getEffects( |
707 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
708 | onlyReadsHandle(getHandlesMutable(), effects); |
709 | onlyReadsPayload(effects); |
710 | } |
711 | |
712 | void transform::MatchStructuredYieldOp::build(OpBuilder &builder, |
713 | OperationState &state) { |
714 | build(builder, state, ValueRange()); |
715 | } |
716 | |
717 | #define GET_OP_CLASSES |
718 | #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc" |
719 | |