//===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h" |
| 10 | #include "mlir/Analysis/SliceAnalysis.h" |
| 11 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 12 | #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" |
| 13 | #include "mlir/Dialect/Linalg/TransformOps/Syntax.h" |
| 14 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 15 | #include "mlir/Dialect/Transform/IR/TransformTypes.h" |
| 16 | #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" |
| 17 | #include "mlir/IR/BuiltinAttributes.h" |
| 18 | #include "llvm/Support/Debug.h" |
| 19 | #include "llvm/Support/FormatVariadic.h" |
| 20 | #include "llvm/Support/InterleavedRange.h" |
| 21 | |
| 22 | using namespace mlir; |
| 23 | |
// Debug category for LLVM_DEBUG output emitted by this file.
#define DEBUG_TYPE "linalg-transforms"
// Starts a debug line on llvm::dbgs() prefixed with "[linalg-transforms]: ".
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
| 26 | |
| 27 | //===----------------------------------------------------------------------===// |
| 28 | // StructuredMatchOp |
| 29 | //===----------------------------------------------------------------------===// |
| 30 | |
| 31 | DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation( |
| 32 | Operation *current, transform::TransformResults &results, |
| 33 | transform::TransformState &state) { |
| 34 | // First, check if the payload operation is a structured Linalg operation. |
| 35 | if (!isa<linalg::LinalgOp>(Val: current)) { |
| 36 | if (getFailurePropagationMode().value_or( |
| 37 | u: FailurePropagationMode::Propagate) == |
| 38 | FailurePropagationMode::Propagate) { |
| 39 | return emitSilenceableError() << "expected a Linalg op" ; |
| 40 | } |
| 41 | // If errors are suppressed, succeed and set all results to empty lists. |
| 42 | LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op" ); |
| 43 | results.setRemainingToEmpty(cast<TransformOpInterface>(Val: getOperation())); |
| 44 | return DiagnosedSilenceableFailure::success(); |
| 45 | } |
| 46 | |
| 47 | // Bind `current` to the block argument. |
| 48 | auto scope = state.make_region_scope(region&: getBodyRegion()); |
| 49 | if (failed(Result: state.mapBlockArgument(argument: getBody()->getArgument(i: 0), |
| 50 | values: MappedValue(current)))) { |
| 51 | return DiagnosedSilenceableFailure::definiteFailure(); |
| 52 | } |
| 53 | |
| 54 | for (Operation &nested : getBody()->without_terminator()) { |
| 55 | DiagnosedSilenceableFailure diag = |
| 56 | state.applyTransform(transform: cast<TransformOpInterface>(Val&: nested)); |
| 57 | if (diag.isDefiniteFailure()) |
| 58 | return diag; |
| 59 | if (diag.succeeded()) |
| 60 | continue; |
| 61 | |
| 62 | // If propagating errors, do this immediately. |
| 63 | assert(diag.isSilenceableFailure()); |
| 64 | if (getFailurePropagationMode().value_or( |
| 65 | u: FailurePropagationMode::Propagate) == |
| 66 | FailurePropagationMode::Propagate) { |
| 67 | return diag; |
| 68 | } |
| 69 | |
| 70 | // If suppressing errors, print the message into the debug stream before |
| 71 | // silencing it. Then set all results value that are already known. |
| 72 | // Results come from the terminator operands, which may be defined in the |
| 73 | // (single) block of this operation or above it. When they are defined |
| 74 | // above, they are known to be mapped at this point per SSA dominance. |
| 75 | // When they are defined in this block, we additionally check if we have |
| 76 | // already applied the operation that defines them. If not, the |
| 77 | // corresponding results will be set to empty lists. |
| 78 | LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage() |
| 79 | << "\n" ); |
| 80 | (void)diag.silence(); |
| 81 | SmallVector<OpOperand *> undefinedOperands; |
| 82 | for (OpOperand &terminatorOperand : |
| 83 | getBody()->getTerminator()->getOpOperands()) { |
| 84 | Operation *definingOp = terminatorOperand.get().getDefiningOp(); |
| 85 | if (!definingOp) |
| 86 | continue; |
| 87 | if (definingOp->getBlock() != getBody()) |
| 88 | continue; |
| 89 | if (definingOp->isBeforeInBlock(other: &nested)) |
| 90 | continue; |
| 91 | |
| 92 | undefinedOperands.push_back(Elt: &terminatorOperand); |
| 93 | } |
| 94 | |
| 95 | SmallVector<SmallVector<transform::MappedValue>> mappings; |
| 96 | auto filtered = llvm::make_filter_range( |
| 97 | Range: getBody()->getTerminator()->getOpOperands(), Pred: [&](OpOperand &opOperand) { |
| 98 | return !llvm::is_contained(Range&: undefinedOperands, Element: &opOperand); |
| 99 | }); |
| 100 | SmallVector<Value> definedOperands = llvm::to_vector(Range: llvm::map_range( |
| 101 | C&: filtered, F: [](OpOperand &opOperand) { return opOperand.get(); })); |
| 102 | detail::prepareValueMappings(mappings, values: definedOperands, state); |
| 103 | for (auto &&[operand, mapping] : llvm::zip_equal(t&: filtered, u&: mappings)) { |
| 104 | results.setMappedValues(handle: getResults()[operand.getOperandNumber()], |
| 105 | values: mapping); |
| 106 | } |
| 107 | results.setRemainingToEmpty(cast<TransformOpInterface>(Val: getOperation())); |
| 108 | return DiagnosedSilenceableFailure::success(); |
| 109 | } |
| 110 | |
| 111 | // Set the results. |
| 112 | detail::forwardTerminatorOperands(block: getBody(), state, results); |
| 113 | return DiagnosedSilenceableFailure::success(); |
| 114 | } |
| 115 | |
| 116 | void transform::MatchStructuredOp::getEffects( |
| 117 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 118 | onlyReadsHandle(handles: getCurrentMutable(), effects); |
| 119 | onlyReadsPayload(effects); |
| 120 | producesHandle(handles: getOperation()->getOpResults(), effects); |
| 121 | } |
| 122 | |
| 123 | LogicalResult transform::MatchStructuredOp::verify() { |
| 124 | if (getBody()->getNumArguments() != 1) |
| 125 | return emitOpError() << "expected one body argument" ; |
| 126 | if (!isa<TransformHandleTypeInterface>(Val: getBody()->getArgument(i: 0).getType())) { |
| 127 | return emitOpError() << "expected body argument to implement " |
| 128 | "TransformHandleTypeInterface" ; |
| 129 | } |
| 130 | for (Operation &nested : getBody()->without_terminator()) { |
| 131 | if (isa<MatchOpInterface>(Val: nested)) |
| 132 | continue; |
| 133 | InFlightDiagnostic diag = |
| 134 | emitOpError() |
| 135 | << "expects nested operations to implement MatchOpInterface" ; |
| 136 | diag.attachNote(noteLoc: nested.getLoc()) << "offending operation" ; |
| 137 | return diag; |
| 138 | } |
| 139 | return success(); |
| 140 | } |
| 141 | |
| 142 | //===----------------------------------------------------------------------===// |
| 143 | // StructuredOpPredicateOpTrait |
| 144 | //===----------------------------------------------------------------------===// |
| 145 | |
| 146 | LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait( |
| 147 | Operation *op, Value structuredOpHandle) { |
| 148 | if (!isa_and_nonnull<MatchStructuredOp>(Val: op->getParentOp())) { |
| 149 | return op->emitOpError() << "expects parent op to be '" |
| 150 | << MatchStructuredOp::getOperationName() << "'" ; |
| 151 | } |
| 152 | |
| 153 | // Bail out here, let the verifier of the parent complain. |
| 154 | Operation *parent = op->getParentOp(); |
| 155 | if (parent->getNumRegions() < 1 || parent->getRegion(index: 0).empty() || |
| 156 | parent->getRegion(index: 0).front().getNumArguments() < 1) |
| 157 | return success(); |
| 158 | |
| 159 | if (structuredOpHandle != parent->getRegion(index: 0).front().getArgument(i: 0)) { |
| 160 | return op->emitOpError() |
| 161 | << "expected predicate to apply to the surrounding structured op" ; |
| 162 | } |
| 163 | return success(); |
| 164 | } |
| 165 | |
| 166 | //===----------------------------------------------------------------------===// |
| 167 | // MatchStructuredBodyOp |
| 168 | //===----------------------------------------------------------------------===// |
| 169 | |
| 170 | DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation( |
| 171 | Operation *current, transform::TransformResults &results, |
| 172 | transform::TransformState &state) { |
| 173 | auto linalgOp = cast<linalg::LinalgOp>(Val: current); |
| 174 | if (std::optional<uint64_t> position = getReductionPosition()) { |
| 175 | SmallVector<Operation *> combinerOps; |
| 176 | if (!matchReduction(iterCarriedArgs: linalgOp.getRegionOutputArgs(), redPos: *position, |
| 177 | combinerOps)) { |
| 178 | return emitSilenceableError() << "could not match reduction" ; |
| 179 | } |
| 180 | if (combinerOps.size() != 1) { |
| 181 | return emitSilenceableError() << "reduction combiner is not a single op" ; |
| 182 | } |
| 183 | return DiagnosedSilenceableFailure::success(); |
| 184 | } |
| 185 | if (getPassthrough()) { |
| 186 | Block &body = linalgOp->getRegion(index: 0).front(); |
| 187 | if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) { |
| 188 | return emitSilenceableError() << "not a passthrough" ; |
| 189 | } |
| 190 | return DiagnosedSilenceableFailure::success(); |
| 191 | } |
| 192 | if (getElementwise()) { |
| 193 | if (!isElementwise(op: linalgOp)) |
| 194 | return emitSilenceableError() << "not elementwise" ; |
| 195 | return DiagnosedSilenceableFailure::success(); |
| 196 | } |
| 197 | if (std::optional<ArrayAttr> contractionOps = getContraction()) { |
| 198 | Block &body = linalgOp->getRegion(index: 0).front(); |
| 199 | std::string message; |
| 200 | llvm::raw_string_ostream os(message); |
| 201 | bool result = linalg::detail::isContractionBody( |
| 202 | block&: body, |
| 203 | isaPair: [&](Operation *elem, Operation *red) { |
| 204 | return elem->getName().getStringRef() == |
| 205 | cast<StringAttr>(Val: (*contractionOps)[0]).getValue() && |
| 206 | red->getName().getStringRef() == |
| 207 | cast<StringAttr>(Val: (*contractionOps)[1]).getValue(); |
| 208 | }, |
| 209 | errs&: os); |
| 210 | if (result) |
| 211 | return DiagnosedSilenceableFailure::success(); |
| 212 | return emitSilenceableError() << "contraction: " << message; |
| 213 | } |
| 214 | return emitDefiniteFailure() << "unknown body condition" ; |
| 215 | } |
| 216 | |
| 217 | LogicalResult transform::MatchStructuredBodyOp::verify() { |
| 218 | int64_t numOptions = getReductionPosition().has_value() + getPassthrough() + |
| 219 | getElementwise() + getContraction().has_value(); |
| 220 | |
| 221 | if (numOptions > 1) { |
| 222 | StringAttr attributeNames[] = { |
| 223 | getReductionPositionAttrName(), getPassthroughAttrName(), |
| 224 | getElementwiseAttrName(), getContractionAttrName()}; |
| 225 | return emitOpError() << "only one of {" << llvm::interleaved(R: attributeNames) |
| 226 | << "} is allowed" ; |
| 227 | } |
| 228 | |
| 229 | if (std::optional<ArrayAttr> contractionAttr = getContraction()) { |
| 230 | if (contractionAttr->size() != 2) { |
| 231 | return emitOpError() << "expects " << getContractionAttrName() |
| 232 | << " to contain two elements" ; |
| 233 | } |
| 234 | } |
| 235 | return success(); |
| 236 | } |
| 237 | |
| 238 | //===----------------------------------------------------------------------===// |
| 239 | // MatchStructuredClassifyContractionDimsOp |
| 240 | //===----------------------------------------------------------------------===// |
| 241 | |
| 242 | DiagnosedSilenceableFailure |
| 243 | transform::MatchStructuredClassifyContractionDimsOp::matchOperation( |
| 244 | Operation *current, transform::TransformResults &results, |
| 245 | transform::TransformState &state) { |
| 246 | FailureOr<linalg::ContractionDimensions> contractionDims = |
| 247 | linalg::inferContractionDims(linalgOp: cast<linalg::LinalgOp>(Val: current)); |
| 248 | if (failed(Result: contractionDims)) |
| 249 | return emitSilenceableError() << "could not infer contraction dimensions" ; |
| 250 | |
| 251 | MLIRContext *context = current->getContext(); |
| 252 | Builder builder(context); |
| 253 | auto makeI64Attrs = [&](ArrayRef<unsigned> values) { |
| 254 | return llvm::to_vector( |
| 255 | Range: llvm::map_range(C&: values, F: [&](unsigned value) -> Attribute { |
| 256 | return builder.getI64IntegerAttr(value); |
| 257 | })); |
| 258 | }; |
| 259 | results.setParams(value: cast<OpResult>(Val: getBatch()), |
| 260 | params: makeI64Attrs(contractionDims->batch)); |
| 261 | results.setParams(value: cast<OpResult>(Val: getM()), params: makeI64Attrs(contractionDims->m)); |
| 262 | results.setParams(value: cast<OpResult>(Val: getN()), params: makeI64Attrs(contractionDims->n)); |
| 263 | results.setParams(value: cast<OpResult>(Val: getK()), params: makeI64Attrs(contractionDims->k)); |
| 264 | return DiagnosedSilenceableFailure::success(); |
| 265 | } |
| 266 | |
| 267 | //===----------------------------------------------------------------------===// |
| 268 | // MatchStructuredClassifyConvolutionDimsOp |
| 269 | //===----------------------------------------------------------------------===// |
| 270 | |
| 271 | DiagnosedSilenceableFailure |
| 272 | transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation( |
| 273 | Operation *current, transform::TransformResults &results, |
| 274 | transform::TransformState &state) { |
| 275 | FailureOr<linalg::ConvolutionDimensions> convolutionDims = |
| 276 | linalg::inferConvolutionDims(linalgOp: cast<linalg::LinalgOp>(Val: current)); |
| 277 | if (failed(Result: convolutionDims)) |
| 278 | return emitSilenceableError() << "could not infer convolution dimensions" ; |
| 279 | |
| 280 | MLIRContext *context = current->getContext(); |
| 281 | Builder builder(context); |
| 282 | auto makeI64Attrs = [&](ArrayRef<unsigned> values) { |
| 283 | return llvm::to_vector( |
| 284 | Range: llvm::map_range(C&: values, F: [&](unsigned value) -> Attribute { |
| 285 | return builder.getI64IntegerAttr(value); |
| 286 | })); |
| 287 | }; |
| 288 | results.setParams(value: cast<OpResult>(Val: getBatch()), |
| 289 | params: makeI64Attrs(convolutionDims->batch)); |
| 290 | results.setParams(value: cast<OpResult>(Val: getOutputImage()), |
| 291 | params: makeI64Attrs(convolutionDims->outputImage)); |
| 292 | results.setParams(value: cast<OpResult>(Val: getOutputChannel()), |
| 293 | params: makeI64Attrs(convolutionDims->outputChannel)); |
| 294 | results.setParams(value: cast<OpResult>(Val: getFilterLoop()), |
| 295 | params: makeI64Attrs(convolutionDims->filterLoop)); |
| 296 | results.setParams(value: cast<OpResult>(Val: getInputChannel()), |
| 297 | params: makeI64Attrs(convolutionDims->inputChannel)); |
| 298 | results.setParams(value: cast<OpResult>(Val: getDepth()), |
| 299 | params: makeI64Attrs(convolutionDims->depth)); |
| 300 | |
| 301 | auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) { |
| 302 | return llvm::to_vector( |
| 303 | Range: llvm::map_range(C&: values, F: [&](int64_t value) -> Attribute { |
| 304 | return builder.getI64IntegerAttr(value); |
| 305 | })); |
| 306 | }; |
| 307 | results.setParams(value: cast<OpResult>(Val: getStrides()), |
| 308 | params: makeI64AttrsFromI64(convolutionDims->strides)); |
| 309 | results.setParams(value: cast<OpResult>(Val: getDilations()), |
| 310 | params: makeI64AttrsFromI64(convolutionDims->dilations)); |
| 311 | return DiagnosedSilenceableFailure::success(); |
| 312 | } |
| 313 | |
| 314 | //===----------------------------------------------------------------------===// |
| 315 | // Utilities for structured match predicates. |
| 316 | //===----------------------------------------------------------------------===// |
| 317 | |
| 318 | /// Checks if all values from `list` are also contained in `reference`. Returns |
| 319 | /// a silenceable error with the given message at the given location when it is |
| 320 | /// not the case. The error message must contain the "{0}" placeholder that |
| 321 | /// will be substituted with the value from `list` that is not contained in |
| 322 | /// `reference`. |
| 323 | static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference, |
| 324 | ArrayRef<int64_t> list, |
| 325 | Location loc, |
| 326 | const char *message) { |
| 327 | for (int64_t value : list) { |
| 328 | if (llvm::any_of(Range&: reference, P: [&](unsigned ref) { |
| 329 | return static_cast<int64_t>(ref) == value; |
| 330 | })) { |
| 331 | continue; |
| 332 | } |
| 333 | return emitSilenceableFailure(loc) << llvm::formatv(Fmt: message, Vals&: value); |
| 334 | } |
| 335 | return DiagnosedSilenceableFailure::success(); |
| 336 | } |
| 337 | |
| 338 | //===----------------------------------------------------------------------===// |
| 339 | // MatchStructuredDimOp |
| 340 | //===----------------------------------------------------------------------===// |
| 341 | |
| 342 | DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation( |
| 343 | Operation *current, transform::TransformResults &results, |
| 344 | transform::TransformState &state) { |
| 345 | auto linalgOp = cast<linalg::LinalgOp>(Val: current); |
| 346 | SmallVector<int64_t> dimensions; |
| 347 | DiagnosedSilenceableFailure diag = getDimensionsFor(op: linalgOp, dims&: dimensions); |
| 348 | if (!diag.succeeded()) |
| 349 | return diag; |
| 350 | |
| 351 | // If asked to check for the kind of dimension, perform the check. |
| 352 | if (getParallel() || getReduction()) { |
| 353 | SmallVector<unsigned> reference; |
| 354 | if (getParallel()) |
| 355 | linalgOp.getParallelDims(res&: reference); |
| 356 | else if (getReduction()) |
| 357 | linalgOp.getReductionDims(res&: reference); |
| 358 | |
| 359 | DiagnosedSilenceableFailure diag = |
| 360 | containsAll(reference, list: dimensions, loc: getLoc(), |
| 361 | message: getParallel() ? "expects dimension #{0} to be parallel" |
| 362 | : "expects dimension #{0} to be reduction" ); |
| 363 | if (!diag.succeeded()) |
| 364 | return diag; |
| 365 | } |
| 366 | |
| 367 | // If not capturing, we are done here. |
| 368 | if (!getResult()) |
| 369 | return diag; |
| 370 | |
| 371 | SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges(); |
| 372 | Builder builder(current); |
| 373 | SmallVector<Attribute> captured = llvm::to_vector( |
| 374 | Range: llvm::map_range(C&: dimensions, F: [&](int64_t dim) -> Attribute { |
| 375 | return builder.getI64IntegerAttr(value: ranges[dim]); |
| 376 | })); |
| 377 | results.setParams(value: cast<OpResult>(Val: getResult()), params: captured); |
| 378 | return DiagnosedSilenceableFailure::success(); |
| 379 | } |
| 380 | |
| 381 | DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor( |
| 382 | linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) { |
| 383 | DiagnosedSilenceableFailure diag = |
| 384 | expandTargetSpecification(loc: getLoc(), isAll: getIsAll(), isInverted: getIsInverted(), |
| 385 | rawList: getRawDimList(), maxNumber: op.getNumLoops(), result&: dims); |
| 386 | if (diag.isSilenceableFailure()) { |
| 387 | diag.attachNote(loc: op->getLoc()) |
| 388 | << "while considering dimensions of this payload operation" ; |
| 389 | } |
| 390 | return diag; |
| 391 | } |
| 392 | |
| 393 | LogicalResult transform::MatchStructuredDimOp::verify() { |
| 394 | if (getParallel() && getReduction()) { |
| 395 | return emitOpError() << "cannot request the same dimension to be both " |
| 396 | "parallel and reduction" ; |
| 397 | } |
| 398 | return verifyTransformMatchDimsOp(op: getOperation(), raw: getRawDimList(), |
| 399 | inverted: getIsInverted(), all: getIsAll()); |
| 400 | } |
| 401 | |
| 402 | //===----------------------------------------------------------------------===// |
| 403 | // MatchStructuredElementalBitwidthOp |
| 404 | //===----------------------------------------------------------------------===// |
| 405 | |
| 406 | DiagnosedSilenceableFailure |
| 407 | transform::MatchStructuredElementalBitwidthOp::matchValue( |
| 408 | Value current, transform::TransformResults &results, |
| 409 | transform::TransformState &state) { |
| 410 | auto setupResult = [&](int64_t bitwidth) { |
| 411 | Attribute attr = Builder(current.getContext()).getI64IntegerAttr(value: bitwidth); |
| 412 | results.setParams(value: cast<OpResult>(Val: getResult()), params: {attr}); |
| 413 | return DiagnosedSilenceableFailure::success(); |
| 414 | }; |
| 415 | |
| 416 | Type type = current.getType(); |
| 417 | if (type.isIntOrFloat()) |
| 418 | return setupResult(type.getIntOrFloatBitWidth()); |
| 419 | |
| 420 | if (auto shapedType = dyn_cast<ShapedType>(Val&: type)) { |
| 421 | if (shapedType.getElementType().isIntOrFloat()) |
| 422 | return setupResult(shapedType.getElementTypeBitWidth()); |
| 423 | } |
| 424 | return emitSilenceableError() |
| 425 | << "unsupported type for bitwidth extraction: " << type; |
| 426 | } |
| 427 | |
| 428 | //===----------------------------------------------------------------------===// |
| 429 | // MatchStructuredInputOp |
| 430 | //===----------------------------------------------------------------------===// |
| 431 | |
| 432 | DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation( |
| 433 | Operation *current, transform::TransformResults &results, |
| 434 | transform::TransformState &state) { |
| 435 | auto linalgOp = cast<linalg::LinalgOp>(Val: current); |
| 436 | SmallVector<int64_t> positions; |
| 437 | DiagnosedSilenceableFailure diag = getPositionsFor(op: linalgOp, positions); |
| 438 | if (!diag.succeeded()) |
| 439 | return diag; |
| 440 | |
| 441 | SmallVector<MappedValue> operandMapping; |
| 442 | operandMapping.reserve(N: positions.size()); |
| 443 | for (int64_t position : positions) { |
| 444 | AffineMap indexingMap = |
| 445 | linalgOp.getMatchingIndexingMap(opOperand: linalgOp.getDpsInputOperand(i: position)); |
| 446 | if (getPermutation() && !indexingMap.isPermutation()) { |
| 447 | return emitSilenceableError() << "the indexing map for input #" |
| 448 | << position << " is not a permutation" ; |
| 449 | } |
| 450 | if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) { |
| 451 | return emitSilenceableError() |
| 452 | << "the indexing map for input #" << position |
| 453 | << " is not a projected permutation" ; |
| 454 | } |
| 455 | |
| 456 | // If capture not requested, skip it. |
| 457 | if (!getResult()) |
| 458 | continue; |
| 459 | |
| 460 | if (isa<AffineMapParamType>(Val: getResult().getType())) { |
| 461 | operandMapping.emplace_back(Args: AffineMapAttr::get(value: indexingMap)); |
| 462 | continue; |
| 463 | } |
| 464 | |
| 465 | Value operand = linalgOp.getDpsInputOperand(i: position)->get(); |
| 466 | if (isa<TransformValueHandleTypeInterface>(Val: getResult().getType())) { |
| 467 | operandMapping.emplace_back(Args&: operand); |
| 468 | continue; |
| 469 | } |
| 470 | |
| 471 | Operation *operandProducer = operand.getDefiningOp(); |
| 472 | if (!operandProducer) { |
| 473 | return emitSilenceableError() |
| 474 | << "input #" << position << " is not produced by an operation" ; |
| 475 | } |
| 476 | operandMapping.emplace_back(Args&: operandProducer); |
| 477 | } |
| 478 | if (getResult()) |
| 479 | results.setMappedValues(handle: cast<OpResult>(Val: getResult()), values: operandMapping); |
| 480 | return DiagnosedSilenceableFailure::success(); |
| 481 | } |
| 482 | |
| 483 | DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor( |
| 484 | linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) { |
| 485 | DiagnosedSilenceableFailure diag = expandTargetSpecification( |
| 486 | loc: getLoc(), isAll: getIsAll(), isInverted: getIsInverted(), rawList: getRawPositionList(), |
| 487 | maxNumber: op.getNumDpsInputs(), result&: positions); |
| 488 | if (diag.isSilenceableFailure()) { |
| 489 | diag.attachNote(loc: op->getLoc()) |
| 490 | << "while considering DPS inputs of this payload operation" ; |
| 491 | } |
| 492 | return diag; |
| 493 | } |
| 494 | |
| 495 | /// Verifies a matcher op for structured input or output, specifically the |
| 496 | /// attributes specifying the operand positions. |
| 497 | template <typename OpTy> |
| 498 | LogicalResult verifyStructuredOperandOp(OpTy op) { |
| 499 | if (op.getPermutation() && op.getProjectedPermutation()) { |
| 500 | return op.emitOpError() |
| 501 | << op.getPermutationAttrName() << " and " |
| 502 | << op.getProjectedPermutationAttrName() << " are mutually exclusive" ; |
| 503 | } |
| 504 | if (op.getRawPositionList().size() > 1 && op.getResult()) { |
| 505 | return op.emitOpError() |
| 506 | << "cannot bind multiple inputs/inits to the same value" ; |
| 507 | } |
| 508 | |
| 509 | return success(); |
| 510 | } |
| 511 | |
| 512 | LogicalResult transform::MatchStructuredInputOp::verify() { |
| 513 | if (failed(Result: verifyStructuredOperandOp(op: *this))) |
| 514 | return failure(); |
| 515 | return verifyTransformMatchDimsOp(op: getOperation(), raw: getRawPositionList(), |
| 516 | inverted: getIsInverted(), all: getIsAll()); |
| 517 | } |
| 518 | |
| 519 | //===----------------------------------------------------------------------===// |
| 520 | // MatchStructuredInitOp |
| 521 | //===----------------------------------------------------------------------===// |
| 522 | |
| 523 | DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation( |
| 524 | Operation *current, transform::TransformResults &results, |
| 525 | transform::TransformState &state) { |
| 526 | auto linalgOp = cast<linalg::LinalgOp>(Val: current); |
| 527 | SmallVector<int64_t> positions; |
| 528 | DiagnosedSilenceableFailure diag = getPositionsFor(op: linalgOp, positions); |
| 529 | if (!diag.succeeded()) |
| 530 | return diag; |
| 531 | |
| 532 | SmallVector<MappedValue> operandMapping; |
| 533 | operandMapping.reserve(N: positions.size()); |
| 534 | for (int64_t position : positions) { |
| 535 | AffineMap indexingMap = |
| 536 | linalgOp.getMatchingIndexingMap(opOperand: linalgOp.getDpsInitOperand(i: position)); |
| 537 | if (getPermutation() && !indexingMap.isPermutation()) { |
| 538 | return emitSilenceableError() << "the indexing map for output(init) #" |
| 539 | << position << " is not a permutation" ; |
| 540 | } |
| 541 | if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) { |
| 542 | return emitSilenceableError() << "the indexing map for output(init) #" |
| 543 | << position << " is not a permutation" ; |
| 544 | } |
| 545 | |
| 546 | // If capture not requested, skip it. |
| 547 | if (!getResult()) |
| 548 | continue; |
| 549 | |
| 550 | if (isa<AffineMapParamType>(Val: getResult().getType())) { |
| 551 | operandMapping.emplace_back(Args: AffineMapAttr::get(value: indexingMap)); |
| 552 | continue; |
| 553 | } |
| 554 | |
| 555 | Value operand = linalgOp.getDpsInitOperand(i: position)->get(); |
| 556 | if (isa<TransformValueHandleTypeInterface>(Val: getResult().getType())) { |
| 557 | operandMapping.emplace_back(Args&: operand); |
| 558 | continue; |
| 559 | } |
| 560 | |
| 561 | Operation *operandProducer = operand.getDefiningOp(); |
| 562 | if (!operandProducer) { |
| 563 | return emitSilenceableError() << "output(init) #" << position |
| 564 | << " is not produced by an operation" ; |
| 565 | } |
| 566 | operandMapping.emplace_back(Args&: operandProducer); |
| 567 | } |
| 568 | if (getResult()) |
| 569 | results.setMappedValues(handle: cast<OpResult>(Val: getResult()), values: operandMapping); |
| 570 | return DiagnosedSilenceableFailure::success(); |
| 571 | } |
| 572 | |
| 573 | DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor( |
| 574 | linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) { |
| 575 | DiagnosedSilenceableFailure diag = expandTargetSpecification( |
| 576 | loc: getLoc(), isAll: getIsAll(), isInverted: getIsInverted(), rawList: getRawPositionList(), |
| 577 | maxNumber: op.getNumDpsInits(), result&: positions); |
| 578 | if (diag.isSilenceableFailure()) { |
| 579 | diag.attachNote(loc: op->getLoc()) |
| 580 | << "while considering DPS inits (outputs) of this payload operation" ; |
| 581 | } |
| 582 | return diag; |
| 583 | } |
| 584 | |
| 585 | LogicalResult transform::MatchStructuredInitOp::verify() { |
| 586 | if (failed(Result: verifyStructuredOperandOp(op: *this))) |
| 587 | return failure(); |
| 588 | return verifyTransformMatchDimsOp(op: getOperation(), raw: getRawPositionList(), |
| 589 | inverted: getIsInverted(), all: getIsAll()); |
| 590 | } |
| 591 | |
| 592 | //===----------------------------------------------------------------------===// |
| 593 | // MatchStructuredNumInputsOp |
| 594 | //===----------------------------------------------------------------------===// |
| 595 | |
| 596 | DiagnosedSilenceableFailure |
| 597 | transform::MatchStructuredNumInputsOp::matchOperation( |
| 598 | Operation *current, transform::TransformResults &results, |
| 599 | transform::TransformState &state) { |
| 600 | auto linalgOp = cast<linalg::LinalgOp>(Val: current); |
| 601 | Attribute attr = |
| 602 | Builder(current).getI64IntegerAttr(value: linalgOp.getNumDpsInputs()); |
| 603 | results.setParams(value: cast<OpResult>(Val: getResult()), params: {attr}); |
| 604 | return DiagnosedSilenceableFailure::success(); |
| 605 | } |
| 606 | |
| 607 | //===----------------------------------------------------------------------===// |
| 608 | // MatchStructuredNumInitsOp |
| 609 | //===----------------------------------------------------------------------===// |
| 610 | |
| 611 | DiagnosedSilenceableFailure |
| 612 | transform::MatchStructuredNumInitsOp::matchOperation( |
| 613 | Operation *current, transform::TransformResults &results, |
| 614 | transform::TransformState &state) { |
| 615 | auto linalgOp = cast<linalg::LinalgOp>(Val: current); |
| 616 | Attribute attr = |
| 617 | Builder(current).getI64IntegerAttr(value: linalgOp.getNumDpsInits()); |
| 618 | results.setParams(value: cast<OpResult>(Val: getResult()), params: {attr}); |
| 619 | return DiagnosedSilenceableFailure::success(); |
| 620 | } |
| 621 | |
| 622 | //===----------------------------------------------------------------------===// |
| 623 | // MatchStructuredRankOp |
| 624 | //===----------------------------------------------------------------------===// |
| 625 | |
| 626 | DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation( |
| 627 | Operation *current, transform::TransformResults &results, |
| 628 | transform::TransformState &state) { |
| 629 | auto linalgOp = cast<linalg::LinalgOp>(Val: current); |
| 630 | int64_t numLoops = linalgOp.getNumLoops(); |
| 631 | Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(value: numLoops); |
| 632 | results.setParams(value: cast<OpResult>(Val: getRank()), params: {attr}); |
| 633 | return DiagnosedSilenceableFailure::success(); |
| 634 | } |
| 635 | |
| 636 | //===----------------------------------------------------------------------===// |
| 637 | // MatchStructuredResultOp |
| 638 | //===----------------------------------------------------------------------===// |
| 639 | |
| 640 | DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation( |
| 641 | Operation *op, transform::TransformResults &results, |
| 642 | transform::TransformState &state) { |
| 643 | auto linalgOp = cast<linalg::LinalgOp>(Val: op); |
| 644 | int64_t position; |
| 645 | DiagnosedSilenceableFailure diag = getPositionFor(op: linalgOp, position); |
| 646 | if (!diag.succeeded()) |
| 647 | return diag; |
| 648 | |
| 649 | Value result = linalgOp.getTiedOpResult(opOperand: linalgOp.getDpsInitOperand(i: position)); |
| 650 | if (isa<TransformValueHandleTypeInterface>(Val: getResult().getType())) { |
| 651 | results.setValues(handle: cast<OpResult>(Val: getResult()), values: {result}); |
| 652 | return DiagnosedSilenceableFailure::success(); |
| 653 | } |
| 654 | |
| 655 | if (result.getUsers().empty()) { |
| 656 | return emitSilenceableError() |
| 657 | << "no users of the result #" << getPosition(); |
| 658 | } |
| 659 | Operation *firstUser = *result.getUsers().begin(); |
| 660 | if (getAny()) { |
| 661 | results.set(value: cast<OpResult>(Val: getResult()), ops: {firstUser}); |
| 662 | return DiagnosedSilenceableFailure::success(); |
| 663 | } |
| 664 | if (getSingle()) { |
| 665 | if (!llvm::hasSingleElement(C: result.getUsers())) { |
| 666 | return emitSilenceableError() |
| 667 | << "more than one result user with single user requested" ; |
| 668 | } |
| 669 | results.set(value: cast<OpResult>(Val: getResult()), ops: {firstUser}); |
| 670 | return DiagnosedSilenceableFailure::success(); |
| 671 | } |
| 672 | |
| 673 | return emitDefiniteFailure() << "unknown sub-predicate" ; |
| 674 | } |
| 675 | |
| 676 | DiagnosedSilenceableFailure |
| 677 | transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op, |
| 678 | int64_t &position) { |
| 679 | auto rawPosition = static_cast<int64_t>(getPosition()); |
| 680 | position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition; |
| 681 | if (position >= op.getNumDpsInits() || position < 0) { |
| 682 | return emitSilenceableError() |
| 683 | << "position " << rawPosition |
| 684 | << " overflows the number of results(ints) of the payload operation" ; |
| 685 | } |
| 686 | return DiagnosedSilenceableFailure::success(); |
| 687 | } |
| 688 | |
| 689 | LogicalResult transform::MatchStructuredResultOp::verify() { |
| 690 | if ((getAny() || getSingle()) ^ |
| 691 | isa<TransformHandleTypeInterface>(Val: getResult().getType())) { |
| 692 | return emitOpError() << "expects either the any/single keyword or the type " |
| 693 | "value handle result type" ; |
| 694 | } |
| 695 | if (getAny() && getSingle()) { |
| 696 | return emitOpError() << "'any' and 'single' are mutually exclusive" ; |
| 697 | } |
| 698 | return success(); |
| 699 | } |
| 700 | |
| 701 | //===----------------------------------------------------------------------===// |
| 702 | // MatchStructuredYieldOp |
| 703 | //===----------------------------------------------------------------------===// |
| 704 | |
| 705 | void transform::MatchStructuredYieldOp::getEffects( |
| 706 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 707 | onlyReadsHandle(handles: getHandlesMutable(), effects); |
| 708 | onlyReadsPayload(effects); |
| 709 | } |
| 710 | |
| 711 | void transform::MatchStructuredYieldOp::build(OpBuilder &builder, |
| 712 | OperationState &state) { |
| 713 | build(odsBuilder&: builder, odsState&: state, handles: ValueRange()); |
| 714 | } |
| 715 | |
| 716 | #define GET_OP_CLASSES |
| 717 | #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc" |
| 718 | |