//===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h" |
10 | #include "mlir/Analysis/SliceAnalysis.h" |
11 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
12 | #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" |
13 | #include "mlir/Dialect/Linalg/TransformOps/Syntax.h" |
14 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
15 | #include "mlir/Dialect/Transform/IR/TransformTypes.h" |
16 | #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" |
17 | #include "mlir/IR/BuiltinAttributes.h" |
18 | #include "mlir/Interfaces/FunctionImplementation.h" |
19 | #include "llvm/Support/Debug.h" |
20 | #include "llvm/Support/FormatVariadic.h" |
21 | |
22 | using namespace mlir; |
23 | |
24 | #define DEBUG_TYPE "linalg-transforms" |
25 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
26 | |
27 | //===----------------------------------------------------------------------===// |
// MatchStructuredOp
29 | //===----------------------------------------------------------------------===// |
30 | |
31 | DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation( |
32 | Operation *current, transform::TransformResults &results, |
33 | transform::TransformState &state) { |
34 | // First, check if the payload operation is a structured Linalg operation. |
35 | if (!isa<linalg::LinalgOp>(current)) { |
36 | if (getFailurePropagationMode().value_or( |
37 | FailurePropagationMode::Propagate) == |
38 | FailurePropagationMode::Propagate) { |
39 | return emitSilenceableError() << "expected a Linalg op" ; |
40 | } |
41 | // If errors are suppressed, succeed and set all results to empty lists. |
42 | LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op" ); |
43 | results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation())); |
44 | return DiagnosedSilenceableFailure::success(); |
45 | } |
46 | |
47 | // Bind `current` to the block argument. |
48 | auto scope = state.make_region_scope(getBodyRegion()); |
49 | if (failed(state.mapBlockArgument(getBody()->getArgument(0), |
50 | MappedValue(current)))) { |
51 | return DiagnosedSilenceableFailure::definiteFailure(); |
52 | } |
53 | |
54 | for (Operation &nested : getBody()->without_terminator()) { |
55 | DiagnosedSilenceableFailure diag = |
56 | state.applyTransform(cast<TransformOpInterface>(nested)); |
57 | if (diag.isDefiniteFailure()) |
58 | return diag; |
59 | if (diag.succeeded()) |
60 | continue; |
61 | |
62 | // If propagating errors, do this immediately. |
63 | assert(diag.isSilenceableFailure()); |
64 | if (getFailurePropagationMode().value_or( |
65 | FailurePropagationMode::Propagate) == |
66 | FailurePropagationMode::Propagate) { |
67 | return diag; |
68 | } |
69 | |
70 | // If suppressing errors, print the message into the debug stream before |
71 | // silencing it. Then set all results value that are already known. |
72 | // Results come from the terminator operands, which may be defined in the |
73 | // (single) block of this operation or above it. When they are defined |
74 | // above, they are known to be mapped at this point per SSA dominance. |
75 | // When they are defined in this block, we additionally check if we have |
76 | // already applied the operation that defines them. If not, the |
77 | // corresponding results will be set to empty lists. |
78 | LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage() |
79 | << "\n" ); |
80 | (void)diag.silence(); |
81 | SmallVector<OpOperand *> undefinedOperands; |
82 | for (OpOperand &terminatorOperand : |
83 | getBody()->getTerminator()->getOpOperands()) { |
84 | Operation *definingOp = terminatorOperand.get().getDefiningOp(); |
85 | if (!definingOp) |
86 | continue; |
87 | if (definingOp->getBlock() != getBody()) |
88 | continue; |
89 | if (definingOp->isBeforeInBlock(&nested)) |
90 | continue; |
91 | |
92 | undefinedOperands.push_back(&terminatorOperand); |
93 | } |
94 | |
95 | SmallVector<SmallVector<transform::MappedValue>> mappings; |
96 | auto filtered = llvm::make_filter_range( |
97 | getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) { |
98 | return !llvm::is_contained(undefinedOperands, &opOperand); |
99 | }); |
100 | SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range( |
101 | filtered, [](OpOperand &opOperand) { return opOperand.get(); })); |
102 | detail::prepareValueMappings(mappings, definedOperands, state); |
103 | for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) { |
104 | results.setMappedValues(getResults()[operand.getOperandNumber()], |
105 | mapping); |
106 | } |
107 | results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation())); |
108 | return DiagnosedSilenceableFailure::success(); |
109 | } |
110 | |
111 | // Set the results. |
112 | detail::forwardTerminatorOperands(getBody(), state, results); |
113 | return DiagnosedSilenceableFailure::success(); |
114 | } |
115 | |
116 | void transform::MatchStructuredOp::getEffects( |
117 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
118 | onlyReadsHandle(getCurrent(), effects); |
119 | onlyReadsPayload(effects); |
120 | producesHandle(getOutputs(), effects); |
121 | } |
122 | |
123 | LogicalResult transform::MatchStructuredOp::verify() { |
124 | if (getBody()->getNumArguments() != 1) |
125 | return emitOpError() << "expected one body argument" ; |
126 | if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) { |
127 | return emitOpError() << "expected body argument to implement " |
128 | "TransformHandleTypeInterface" ; |
129 | } |
130 | for (Operation &nested : getBody()->without_terminator()) { |
131 | if (isa<MatchOpInterface>(nested)) |
132 | continue; |
133 | InFlightDiagnostic diag = |
134 | emitOpError() |
135 | << "expects nested operations to implement MatchOpInterface" ; |
136 | diag.attachNote(nested.getLoc()) << "offending operation" ; |
137 | return diag; |
138 | } |
139 | return success(); |
140 | } |
141 | |
142 | //===----------------------------------------------------------------------===// |
143 | // StructuredOpPredicateOpTrait |
144 | //===----------------------------------------------------------------------===// |
145 | |
146 | LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait( |
147 | Operation *op, Value structuredOpHandle) { |
148 | if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) { |
149 | return op->emitOpError() << "expects parent op to be '" |
150 | << MatchStructuredOp::getOperationName() << "'" ; |
151 | } |
152 | |
153 | // Bail out here, let the verifier of the parent complain. |
154 | Operation *parent = op->getParentOp(); |
155 | if (parent->getNumRegions() < 1 || parent->getRegion(index: 0).empty() || |
156 | parent->getRegion(index: 0).front().getNumArguments() < 1) |
157 | return success(); |
158 | |
159 | if (structuredOpHandle != parent->getRegion(index: 0).front().getArgument(i: 0)) { |
160 | return op->emitOpError() |
161 | << "expected predicate to apply to the surrounding structured op" ; |
162 | } |
163 | return success(); |
164 | } |
165 | |
166 | //===----------------------------------------------------------------------===// |
167 | // MatchStructuredBodyOp |
168 | //===----------------------------------------------------------------------===// |
169 | |
170 | DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation( |
171 | Operation *current, transform::TransformResults &results, |
172 | transform::TransformState &state) { |
173 | auto linalgOp = cast<linalg::LinalgOp>(current); |
174 | if (std::optional<uint64_t> position = getReductionPosition()) { |
175 | SmallVector<Operation *> combinerOps; |
176 | if (!matchReduction(linalgOp.getRegionOutputArgs(), *position, |
177 | combinerOps)) { |
178 | return emitSilenceableError() << "could not match reduction" ; |
179 | } |
180 | if (combinerOps.size() != 1) { |
181 | return emitSilenceableError() << "reduction combiner is not a single op" ; |
182 | } |
183 | return DiagnosedSilenceableFailure::success(); |
184 | } |
185 | if (getPassthrough()) { |
186 | Block &body = linalgOp->getRegion(0).front(); |
187 | if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) { |
188 | return emitSilenceableError() << "not a passthrough" ; |
189 | } |
190 | return DiagnosedSilenceableFailure::success(); |
191 | } |
192 | if (getElementwise()) { |
193 | if (!isElementwise(linalgOp)) |
194 | return emitSilenceableError() << "not elementwise" ; |
195 | return DiagnosedSilenceableFailure::success(); |
196 | } |
197 | if (std::optional<ArrayAttr> contractionOps = getContraction()) { |
198 | Block &body = linalgOp->getRegion(0).front(); |
199 | std::string message; |
200 | llvm::raw_string_ostream os(message); |
201 | bool result = linalg::detail::isContractionBody( |
202 | body, |
203 | [&](Operation *elem, Operation *red) { |
204 | return elem->getName().getStringRef() == |
205 | cast<StringAttr>((*contractionOps)[0]).getValue() && |
206 | red->getName().getStringRef() == |
207 | cast<StringAttr>((*contractionOps)[1]).getValue(); |
208 | }, |
209 | os); |
210 | if (result) |
211 | return DiagnosedSilenceableFailure::success(); |
212 | return emitSilenceableError() << "contraction: " << os.str(); |
213 | } |
214 | return emitDefiniteFailure() << "unknown body condition" ; |
215 | } |
216 | |
217 | LogicalResult transform::MatchStructuredBodyOp::verify() { |
218 | int64_t numOptions = getReductionPosition().has_value() + getPassthrough() + |
219 | getElementwise() + getContraction().has_value(); |
220 | |
221 | if (numOptions > 1) { |
222 | std::string attributeNames; |
223 | llvm::raw_string_ostream os(attributeNames); |
224 | llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(), |
225 | getPassthroughAttrName(), |
226 | getElementwiseAttrName(), |
227 | getContractionAttrName()}, |
228 | os); |
229 | return emitOpError() << "only one of {" << os.str() << "} is allowed" ; |
230 | } |
231 | |
232 | if (std::optional<ArrayAttr> contractionAttr = getContraction()) { |
233 | if (contractionAttr->size() != 2) { |
234 | return emitOpError() << "expects " << getContractionAttrName() |
235 | << " to contain two elements" ; |
236 | } |
237 | } |
238 | return success(); |
239 | } |
240 | |
241 | //===----------------------------------------------------------------------===// |
242 | // MatchStructuredClassifyContractionDimsOp |
243 | //===----------------------------------------------------------------------===// |
244 | |
245 | DiagnosedSilenceableFailure |
246 | transform::MatchStructuredClassifyContractionDimsOp::matchOperation( |
247 | Operation *current, transform::TransformResults &results, |
248 | transform::TransformState &state) { |
249 | FailureOr<linalg::ContractionDimensions> contractionDims = |
250 | linalg::inferContractionDims(cast<linalg::LinalgOp>(current)); |
251 | if (failed(contractionDims)) |
252 | return emitSilenceableError() << "could not infer contraction dimensions" ; |
253 | |
254 | MLIRContext *context = current->getContext(); |
255 | Builder builder(context); |
256 | auto makeI64Attrs = [&](ArrayRef<unsigned> values) { |
257 | return llvm::to_vector( |
258 | llvm::map_range(values, [&](unsigned value) -> Attribute { |
259 | return builder.getI64IntegerAttr(value); |
260 | })); |
261 | }; |
262 | results.setParams(cast<OpResult>(getBatch()), |
263 | makeI64Attrs(contractionDims->batch)); |
264 | results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m)); |
265 | results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n)); |
266 | results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k)); |
267 | return DiagnosedSilenceableFailure::success(); |
268 | } |
269 | |
270 | //===----------------------------------------------------------------------===// |
271 | // MatchStructuredClassifyConvolutionDimsOp |
272 | //===----------------------------------------------------------------------===// |
273 | |
274 | DiagnosedSilenceableFailure |
275 | transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation( |
276 | Operation *current, transform::TransformResults &results, |
277 | transform::TransformState &state) { |
278 | FailureOr<linalg::ConvolutionDimensions> convolutionDims = |
279 | linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current)); |
280 | if (failed(convolutionDims)) |
281 | return emitSilenceableError() << "could not infer convolution dimensions" ; |
282 | |
283 | MLIRContext *context = current->getContext(); |
284 | Builder builder(context); |
285 | auto makeI64Attrs = [&](ArrayRef<unsigned> values) { |
286 | return llvm::to_vector( |
287 | llvm::map_range(values, [&](unsigned value) -> Attribute { |
288 | return builder.getI64IntegerAttr(value); |
289 | })); |
290 | }; |
291 | results.setParams(cast<OpResult>(getBatch()), |
292 | makeI64Attrs(convolutionDims->batch)); |
293 | results.setParams(cast<OpResult>(getOutputImage()), |
294 | makeI64Attrs(convolutionDims->outputImage)); |
295 | results.setParams(cast<OpResult>(getOutputChannel()), |
296 | makeI64Attrs(convolutionDims->outputChannel)); |
297 | results.setParams(cast<OpResult>(getFilterLoop()), |
298 | makeI64Attrs(convolutionDims->filterLoop)); |
299 | results.setParams(cast<OpResult>(getInputChannel()), |
300 | makeI64Attrs(convolutionDims->inputChannel)); |
301 | results.setParams(cast<OpResult>(getDepth()), |
302 | makeI64Attrs(convolutionDims->depth)); |
303 | |
304 | auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) { |
305 | return llvm::to_vector( |
306 | llvm::map_range(values, [&](int64_t value) -> Attribute { |
307 | return builder.getI64IntegerAttr(value); |
308 | })); |
309 | }; |
310 | results.setParams(cast<OpResult>(getStrides()), |
311 | makeI64AttrsFromI64(convolutionDims->strides)); |
312 | results.setParams(cast<OpResult>(getDilations()), |
313 | makeI64AttrsFromI64(convolutionDims->dilations)); |
314 | return DiagnosedSilenceableFailure::success(); |
315 | } |
316 | |
317 | //===----------------------------------------------------------------------===// |
318 | // Utilities for structured match predicates. |
319 | //===----------------------------------------------------------------------===// |
320 | |
321 | /// Checks if all values from `list` are also contained in `reference`. Returns |
322 | /// a silenceable error with the given message at the given location when it is |
323 | /// not the case. The error message must contain the "{0}" placeholder that |
324 | /// will be substituted with the value from `list` that is not contained in |
325 | /// `reference`. |
326 | static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference, |
327 | ArrayRef<int64_t> list, |
328 | Location loc, |
329 | const char *message) { |
330 | for (int64_t value : list) { |
331 | if (llvm::any_of(Range&: reference, P: [&](unsigned ref) { |
332 | return static_cast<int64_t>(ref) == value; |
333 | })) { |
334 | continue; |
335 | } |
336 | return emitSilenceableFailure(loc) << llvm::formatv(Fmt: message, Vals&: value); |
337 | } |
338 | return DiagnosedSilenceableFailure::success(); |
339 | } |
340 | |
341 | //===----------------------------------------------------------------------===// |
342 | // MatchStructuredDimOp |
343 | //===----------------------------------------------------------------------===// |
344 | |
345 | DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation( |
346 | Operation *current, transform::TransformResults &results, |
347 | transform::TransformState &state) { |
348 | auto linalgOp = cast<linalg::LinalgOp>(current); |
349 | SmallVector<int64_t> dimensions; |
350 | DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions); |
351 | if (!diag.succeeded()) |
352 | return diag; |
353 | |
354 | // If asked to check for the kind of dimension, perform the check. |
355 | if (getParallel() || getReduction()) { |
356 | SmallVector<unsigned> reference; |
357 | if (getParallel()) |
358 | linalgOp.getParallelDims(reference); |
359 | else if (getReduction()) |
360 | linalgOp.getReductionDims(reference); |
361 | |
362 | DiagnosedSilenceableFailure diag = |
363 | containsAll(reference, dimensions, getLoc(), |
364 | getParallel() ? "expects dimension #{0} to be parallel" |
365 | : "expects dimension #{0} to be reduction" ); |
366 | if (!diag.succeeded()) |
367 | return diag; |
368 | } |
369 | |
370 | // If not capturing, we are done here. |
371 | if (!getResult()) |
372 | return diag; |
373 | |
374 | SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges(); |
375 | Builder builder(current); |
376 | SmallVector<Attribute> captured = llvm::to_vector( |
377 | llvm::map_range(dimensions, [&](int64_t dim) -> Attribute { |
378 | return builder.getI64IntegerAttr(ranges[dim]); |
379 | })); |
380 | results.setParams(cast<OpResult>(getResult()), captured); |
381 | return DiagnosedSilenceableFailure::success(); |
382 | } |
383 | |
384 | DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor( |
385 | linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) { |
386 | DiagnosedSilenceableFailure diag = |
387 | expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(), |
388 | getRawDimList(), op.getNumLoops(), dims); |
389 | if (diag.isSilenceableFailure()) { |
390 | diag.attachNote(op->getLoc()) |
391 | << "while considering dimensions of this payload operation" ; |
392 | } |
393 | return diag; |
394 | } |
395 | |
396 | LogicalResult transform::MatchStructuredDimOp::verify() { |
397 | if (getParallel() && getReduction()) { |
398 | return emitOpError() << "cannot request the same dimension to be both " |
399 | "parallel and reduction" ; |
400 | } |
401 | return verifyTransformMatchDimsOp(getOperation(), getRawDimList(), |
402 | getIsInverted(), getIsAll()); |
403 | } |
404 | |
405 | //===----------------------------------------------------------------------===// |
406 | // MatchStructuredElementalBitwidthOp |
407 | //===----------------------------------------------------------------------===// |
408 | |
409 | DiagnosedSilenceableFailure |
410 | transform::MatchStructuredElementalBitwidthOp::matchValue( |
411 | Value current, transform::TransformResults &results, |
412 | transform::TransformState &state) { |
413 | auto setupResult = [&](int64_t bitwidth) { |
414 | Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth); |
415 | results.setParams(cast<OpResult>(getResult()), {attr}); |
416 | return DiagnosedSilenceableFailure::success(); |
417 | }; |
418 | |
419 | Type type = current.getType(); |
420 | if (type.isIntOrFloat()) |
421 | return setupResult(type.getIntOrFloatBitWidth()); |
422 | |
423 | if (auto shapedType = dyn_cast<ShapedType>(type)) { |
424 | if (shapedType.getElementType().isIntOrFloat()) |
425 | return setupResult(shapedType.getElementTypeBitWidth()); |
426 | } |
427 | return emitSilenceableError() |
428 | << "unsupported type for bitwidth extraction: " << type; |
429 | } |
430 | |
431 | //===----------------------------------------------------------------------===// |
432 | // MatchStructuredInputOp |
433 | //===----------------------------------------------------------------------===// |
434 | |
435 | DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation( |
436 | Operation *current, transform::TransformResults &results, |
437 | transform::TransformState &state) { |
438 | auto linalgOp = cast<linalg::LinalgOp>(current); |
439 | SmallVector<int64_t> positions; |
440 | DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions); |
441 | if (!diag.succeeded()) |
442 | return diag; |
443 | |
444 | SmallVector<MappedValue> operandMapping; |
445 | operandMapping.reserve(positions.size()); |
446 | for (int64_t position : positions) { |
447 | AffineMap indexingMap = |
448 | linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position)); |
449 | if (getPermutation() && !indexingMap.isPermutation()) { |
450 | return emitSilenceableError() << "the indexing map for input #" |
451 | << position << " is not a permutation" ; |
452 | } |
453 | if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) { |
454 | return emitSilenceableError() |
455 | << "the indexing map for input #" << position |
456 | << " is not a projected permutation" ; |
457 | } |
458 | |
459 | // If capture not requested, skip it. |
460 | if (!getResult()) |
461 | continue; |
462 | |
463 | if (isa<AffineMapParamType>(getResult().getType())) { |
464 | operandMapping.emplace_back(AffineMapAttr::get(indexingMap)); |
465 | continue; |
466 | } |
467 | |
468 | Value operand = linalgOp.getDpsInputOperand(position)->get(); |
469 | if (isa<TransformValueHandleTypeInterface>(getResult().getType())) { |
470 | operandMapping.emplace_back(operand); |
471 | continue; |
472 | } |
473 | |
474 | Operation *operandProducer = operand.getDefiningOp(); |
475 | if (!operandProducer) { |
476 | return emitSilenceableError() |
477 | << "input #" << position << " is not produced by an operation" ; |
478 | } |
479 | operandMapping.emplace_back(operandProducer); |
480 | } |
481 | if (getResult()) |
482 | results.setMappedValues(cast<OpResult>(getResult()), operandMapping); |
483 | return DiagnosedSilenceableFailure::success(); |
484 | } |
485 | |
486 | DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor( |
487 | linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) { |
488 | DiagnosedSilenceableFailure diag = expandTargetSpecification( |
489 | getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), |
490 | op.getNumDpsInputs(), positions); |
491 | if (diag.isSilenceableFailure()) { |
492 | diag.attachNote(op->getLoc()) |
493 | << "while considering DPS inputs of this payload operation" ; |
494 | } |
495 | return diag; |
496 | } |
497 | |
498 | /// Verifies a matcher op for structured input or output, specifically the |
499 | /// attributes specifying the operand positions. |
500 | template <typename OpTy> |
501 | LogicalResult verifyStructuredOperandOp(OpTy op) { |
502 | if (op.getPermutation() && op.getProjectedPermutation()) { |
503 | return op.emitOpError() |
504 | << op.getPermutationAttrName() << " and " |
505 | << op.getProjectedPermutationAttrName() << " are mutually exclusive" ; |
506 | } |
507 | if (op.getRawPositionList().size() > 1 && op.getResult()) { |
508 | return op.emitOpError() |
509 | << "cannot bind multiple inputs/inits to the same value" ; |
510 | } |
511 | |
512 | return success(); |
513 | } |
514 | |
515 | LogicalResult transform::MatchStructuredInputOp::verify() { |
516 | if (failed(verifyStructuredOperandOp(*this))) |
517 | return failure(); |
518 | return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), |
519 | getIsInverted(), getIsAll()); |
520 | } |
521 | |
522 | //===----------------------------------------------------------------------===// |
523 | // MatchStructuredInitOp |
524 | //===----------------------------------------------------------------------===// |
525 | |
526 | DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation( |
527 | Operation *current, transform::TransformResults &results, |
528 | transform::TransformState &state) { |
529 | auto linalgOp = cast<linalg::LinalgOp>(current); |
530 | SmallVector<int64_t> positions; |
531 | DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions); |
532 | if (!diag.succeeded()) |
533 | return diag; |
534 | |
535 | SmallVector<MappedValue> operandMapping; |
536 | operandMapping.reserve(positions.size()); |
537 | for (int64_t position : positions) { |
538 | AffineMap indexingMap = |
539 | linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position)); |
540 | if (getPermutation() && !indexingMap.isPermutation()) { |
541 | return emitSilenceableError() << "the indexing map for output(init) #" |
542 | << position << " is not a permutation" ; |
543 | } |
544 | if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) { |
545 | return emitSilenceableError() << "the indexing map for output(init) #" |
546 | << position << " is not a permutation" ; |
547 | } |
548 | |
549 | // If capture not requested, skip it. |
550 | if (!getResult()) |
551 | continue; |
552 | |
553 | if (isa<AffineMapParamType>(getResult().getType())) { |
554 | operandMapping.emplace_back(AffineMapAttr::get(indexingMap)); |
555 | continue; |
556 | } |
557 | |
558 | Value operand = linalgOp.getDpsInitOperand(position)->get(); |
559 | if (isa<TransformValueHandleTypeInterface>(getResult().getType())) { |
560 | operandMapping.emplace_back(operand); |
561 | continue; |
562 | } |
563 | |
564 | Operation *operandProducer = operand.getDefiningOp(); |
565 | if (!operandProducer) { |
566 | return emitSilenceableError() << "output(init) #" << position |
567 | << " is not produced by an operation" ; |
568 | } |
569 | operandMapping.emplace_back(operandProducer); |
570 | } |
571 | if (getResult()) |
572 | results.setMappedValues(cast<OpResult>(getResult()), operandMapping); |
573 | return DiagnosedSilenceableFailure::success(); |
574 | } |
575 | |
576 | DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor( |
577 | linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) { |
578 | DiagnosedSilenceableFailure diag = expandTargetSpecification( |
579 | getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), |
580 | op.getNumDpsInits(), positions); |
581 | if (diag.isSilenceableFailure()) { |
582 | diag.attachNote(op->getLoc()) |
583 | << "while considering DPS inits (outputs) of this payload operation" ; |
584 | } |
585 | return diag; |
586 | } |
587 | |
588 | LogicalResult transform::MatchStructuredInitOp::verify() { |
589 | if (failed(verifyStructuredOperandOp(*this))) |
590 | return failure(); |
591 | return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), |
592 | getIsInverted(), getIsAll()); |
593 | } |
594 | |
595 | //===----------------------------------------------------------------------===// |
596 | // MatchStructuredNumInputsOp |
597 | //===----------------------------------------------------------------------===// |
598 | |
599 | DiagnosedSilenceableFailure |
600 | transform::MatchStructuredNumInputsOp::matchOperation( |
601 | Operation *current, transform::TransformResults &results, |
602 | transform::TransformState &state) { |
603 | auto linalgOp = cast<linalg::LinalgOp>(current); |
604 | Attribute attr = |
605 | Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs()); |
606 | results.setParams(cast<OpResult>(getResult()), {attr}); |
607 | return DiagnosedSilenceableFailure::success(); |
608 | } |
609 | |
610 | //===----------------------------------------------------------------------===// |
611 | // MatchStructuredNumInitsOp |
612 | //===----------------------------------------------------------------------===// |
613 | |
614 | DiagnosedSilenceableFailure |
615 | transform::MatchStructuredNumInitsOp::matchOperation( |
616 | Operation *current, transform::TransformResults &results, |
617 | transform::TransformState &state) { |
618 | auto linalgOp = cast<linalg::LinalgOp>(current); |
619 | Attribute attr = |
620 | Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits()); |
621 | results.setParams(cast<OpResult>(getResult()), {attr}); |
622 | return DiagnosedSilenceableFailure::success(); |
623 | } |
624 | |
625 | //===----------------------------------------------------------------------===// |
626 | // MatchStructuredRankOp |
627 | //===----------------------------------------------------------------------===// |
628 | |
629 | DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation( |
630 | Operation *current, transform::TransformResults &results, |
631 | transform::TransformState &state) { |
632 | auto linalgOp = cast<linalg::LinalgOp>(current); |
633 | int64_t numLoops = linalgOp.getNumLoops(); |
634 | Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops); |
635 | results.setParams(cast<OpResult>(getRank()), {attr}); |
636 | return DiagnosedSilenceableFailure::success(); |
637 | } |
638 | |
639 | //===----------------------------------------------------------------------===// |
640 | // MatchStructuredResultOp |
641 | //===----------------------------------------------------------------------===// |
642 | |
643 | DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation( |
644 | Operation *op, transform::TransformResults &results, |
645 | transform::TransformState &state) { |
646 | auto linalgOp = cast<linalg::LinalgOp>(op); |
647 | int64_t position; |
648 | DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position); |
649 | if (!diag.succeeded()) |
650 | return diag; |
651 | |
652 | Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position)); |
653 | if (isa<TransformValueHandleTypeInterface>(getResult().getType())) { |
654 | results.setValues(cast<OpResult>(getResult()), {result}); |
655 | return DiagnosedSilenceableFailure::success(); |
656 | } |
657 | |
658 | if (result.getUsers().empty()) { |
659 | return emitSilenceableError() |
660 | << "no users of the result #" << getPosition(); |
661 | } |
662 | Operation *firstUser = *result.getUsers().begin(); |
663 | if (getAny()) { |
664 | results.set(cast<OpResult>(getResult()), {firstUser}); |
665 | return DiagnosedSilenceableFailure::success(); |
666 | } |
667 | if (getSingle()) { |
668 | if (!llvm::hasSingleElement(result.getUsers())) { |
669 | return emitSilenceableError() |
670 | << "more than one result user with single user requested" ; |
671 | } |
672 | results.set(cast<OpResult>(getResult()), {firstUser}); |
673 | return DiagnosedSilenceableFailure::success(); |
674 | } |
675 | |
676 | return emitDefiniteFailure() << "unknown sub-predicate" ; |
677 | } |
678 | |
679 | DiagnosedSilenceableFailure |
680 | transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op, |
681 | int64_t &position) { |
682 | auto rawPosition = static_cast<int64_t>(getPosition()); |
683 | position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition; |
684 | if (position >= op.getNumDpsInits() || position < 0) { |
685 | return emitSilenceableError() |
686 | << "position " << rawPosition |
687 | << " overflows the number of results(ints) of the payload operation" ; |
688 | } |
689 | return DiagnosedSilenceableFailure::success(); |
690 | } |
691 | |
692 | LogicalResult transform::MatchStructuredResultOp::verify() { |
693 | if ((getAny() || getSingle()) ^ |
694 | isa<TransformHandleTypeInterface>(getResult().getType())) { |
695 | return emitOpError() << "expects either the any/single keyword or the type " |
696 | "value handle result type" ; |
697 | } |
698 | if (getAny() && getSingle()) { |
699 | return emitOpError() << "'any' and 'single' are mutually exclusive" ; |
700 | } |
701 | return success(); |
702 | } |
703 | |
704 | //===----------------------------------------------------------------------===// |
705 | // MatchStructuredYieldOp |
706 | //===----------------------------------------------------------------------===// |
707 | |
708 | void transform::MatchStructuredYieldOp::getEffects( |
709 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
710 | onlyReadsHandle(getHandles(), effects); |
711 | onlyReadsPayload(effects); |
712 | } |
713 | |
714 | void transform::MatchStructuredYieldOp::build(OpBuilder &builder, |
715 | OperationState &state) { |
716 | build(builder, state, ValueRange()); |
717 | } |
718 | |
719 | #define GET_OP_CLASSES |
720 | #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc" |
721 | |