| 1 | //===- LinalgTransformOps.cpp - Implementation of Linalg match ops --------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h" |
| 10 | #include "mlir/Analysis/SliceAnalysis.h" |
| 11 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 12 | #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" |
| 13 | #include "mlir/Dialect/Linalg/TransformOps/Syntax.h" |
| 14 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 15 | #include "mlir/Dialect/Transform/IR/TransformTypes.h" |
| 16 | #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" |
| 17 | #include "mlir/IR/BuiltinAttributes.h" |
| 18 | #include "mlir/Interfaces/FunctionImplementation.h" |
| 19 | #include "llvm/Support/Debug.h" |
| 20 | #include "llvm/Support/FormatVariadic.h" |
| 21 | #include "llvm/Support/InterleavedRange.h" |
| 22 | |
| 23 | using namespace mlir; |
| 24 | |
| 25 | #define DEBUG_TYPE "linalg-transforms" |
| 26 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
| 27 | |
| 28 | //===----------------------------------------------------------------------===// |
| 29 | // StructuredMatchOp |
| 30 | //===----------------------------------------------------------------------===// |
| 31 | |
| 32 | DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation( |
| 33 | Operation *current, transform::TransformResults &results, |
| 34 | transform::TransformState &state) { |
| 35 | // First, check if the payload operation is a structured Linalg operation. |
| 36 | if (!isa<linalg::LinalgOp>(current)) { |
| 37 | if (getFailurePropagationMode().value_or( |
| 38 | FailurePropagationMode::Propagate) == |
| 39 | FailurePropagationMode::Propagate) { |
| 40 | return emitSilenceableError() << "expected a Linalg op" ; |
| 41 | } |
| 42 | // If errors are suppressed, succeed and set all results to empty lists. |
| 43 | LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op" ); |
| 44 | results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation())); |
| 45 | return DiagnosedSilenceableFailure::success(); |
| 46 | } |
| 47 | |
| 48 | // Bind `current` to the block argument. |
| 49 | auto scope = state.make_region_scope(getBodyRegion()); |
| 50 | if (failed(state.mapBlockArgument(getBody()->getArgument(0), |
| 51 | MappedValue(current)))) { |
| 52 | return DiagnosedSilenceableFailure::definiteFailure(); |
| 53 | } |
| 54 | |
| 55 | for (Operation &nested : getBody()->without_terminator()) { |
| 56 | DiagnosedSilenceableFailure diag = |
| 57 | state.applyTransform(cast<TransformOpInterface>(nested)); |
| 58 | if (diag.isDefiniteFailure()) |
| 59 | return diag; |
| 60 | if (diag.succeeded()) |
| 61 | continue; |
| 62 | |
| 63 | // If propagating errors, do this immediately. |
| 64 | assert(diag.isSilenceableFailure()); |
| 65 | if (getFailurePropagationMode().value_or( |
| 66 | FailurePropagationMode::Propagate) == |
| 67 | FailurePropagationMode::Propagate) { |
| 68 | return diag; |
| 69 | } |
| 70 | |
| 71 | // If suppressing errors, print the message into the debug stream before |
| 72 | // silencing it. Then set all results value that are already known. |
| 73 | // Results come from the terminator operands, which may be defined in the |
| 74 | // (single) block of this operation or above it. When they are defined |
| 75 | // above, they are known to be mapped at this point per SSA dominance. |
| 76 | // When they are defined in this block, we additionally check if we have |
| 77 | // already applied the operation that defines them. If not, the |
| 78 | // corresponding results will be set to empty lists. |
| 79 | LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage() |
| 80 | << "\n" ); |
| 81 | (void)diag.silence(); |
| 82 | SmallVector<OpOperand *> undefinedOperands; |
| 83 | for (OpOperand &terminatorOperand : |
| 84 | getBody()->getTerminator()->getOpOperands()) { |
| 85 | Operation *definingOp = terminatorOperand.get().getDefiningOp(); |
| 86 | if (!definingOp) |
| 87 | continue; |
| 88 | if (definingOp->getBlock() != getBody()) |
| 89 | continue; |
| 90 | if (definingOp->isBeforeInBlock(&nested)) |
| 91 | continue; |
| 92 | |
| 93 | undefinedOperands.push_back(&terminatorOperand); |
| 94 | } |
| 95 | |
| 96 | SmallVector<SmallVector<transform::MappedValue>> mappings; |
| 97 | auto filtered = llvm::make_filter_range( |
| 98 | getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) { |
| 99 | return !llvm::is_contained(undefinedOperands, &opOperand); |
| 100 | }); |
| 101 | SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range( |
| 102 | filtered, [](OpOperand &opOperand) { return opOperand.get(); })); |
| 103 | detail::prepareValueMappings(mappings, definedOperands, state); |
| 104 | for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) { |
| 105 | results.setMappedValues(getResults()[operand.getOperandNumber()], |
| 106 | mapping); |
| 107 | } |
| 108 | results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation())); |
| 109 | return DiagnosedSilenceableFailure::success(); |
| 110 | } |
| 111 | |
| 112 | // Set the results. |
| 113 | detail::forwardTerminatorOperands(getBody(), state, results); |
| 114 | return DiagnosedSilenceableFailure::success(); |
| 115 | } |
| 116 | |
| 117 | void transform::MatchStructuredOp::getEffects( |
| 118 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 119 | onlyReadsHandle(getCurrentMutable(), effects); |
| 120 | onlyReadsPayload(effects); |
| 121 | producesHandle(getOperation()->getOpResults(), effects); |
| 122 | } |
| 123 | |
| 124 | LogicalResult transform::MatchStructuredOp::verify() { |
| 125 | if (getBody()->getNumArguments() != 1) |
| 126 | return emitOpError() << "expected one body argument" ; |
| 127 | if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) { |
| 128 | return emitOpError() << "expected body argument to implement " |
| 129 | "TransformHandleTypeInterface" ; |
| 130 | } |
| 131 | for (Operation &nested : getBody()->without_terminator()) { |
| 132 | if (isa<MatchOpInterface>(nested)) |
| 133 | continue; |
| 134 | InFlightDiagnostic diag = |
| 135 | emitOpError() |
| 136 | << "expects nested operations to implement MatchOpInterface" ; |
| 137 | diag.attachNote(nested.getLoc()) << "offending operation" ; |
| 138 | return diag; |
| 139 | } |
| 140 | return success(); |
| 141 | } |
| 142 | |
| 143 | //===----------------------------------------------------------------------===// |
| 144 | // StructuredOpPredicateOpTrait |
| 145 | //===----------------------------------------------------------------------===// |
| 146 | |
| 147 | LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait( |
| 148 | Operation *op, Value structuredOpHandle) { |
| 149 | if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) { |
| 150 | return op->emitOpError() << "expects parent op to be '" |
| 151 | << MatchStructuredOp::getOperationName() << "'" ; |
| 152 | } |
| 153 | |
| 154 | // Bail out here, let the verifier of the parent complain. |
| 155 | Operation *parent = op->getParentOp(); |
| 156 | if (parent->getNumRegions() < 1 || parent->getRegion(index: 0).empty() || |
| 157 | parent->getRegion(index: 0).front().getNumArguments() < 1) |
| 158 | return success(); |
| 159 | |
| 160 | if (structuredOpHandle != parent->getRegion(index: 0).front().getArgument(i: 0)) { |
| 161 | return op->emitOpError() |
| 162 | << "expected predicate to apply to the surrounding structured op" ; |
| 163 | } |
| 164 | return success(); |
| 165 | } |
| 166 | |
| 167 | //===----------------------------------------------------------------------===// |
| 168 | // MatchStructuredBodyOp |
| 169 | //===----------------------------------------------------------------------===// |
| 170 | |
| 171 | DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation( |
| 172 | Operation *current, transform::TransformResults &results, |
| 173 | transform::TransformState &state) { |
| 174 | auto linalgOp = cast<linalg::LinalgOp>(current); |
| 175 | if (std::optional<uint64_t> position = getReductionPosition()) { |
| 176 | SmallVector<Operation *> combinerOps; |
| 177 | if (!matchReduction(linalgOp.getRegionOutputArgs(), *position, |
| 178 | combinerOps)) { |
| 179 | return emitSilenceableError() << "could not match reduction" ; |
| 180 | } |
| 181 | if (combinerOps.size() != 1) { |
| 182 | return emitSilenceableError() << "reduction combiner is not a single op" ; |
| 183 | } |
| 184 | return DiagnosedSilenceableFailure::success(); |
| 185 | } |
| 186 | if (getPassthrough()) { |
| 187 | Block &body = linalgOp->getRegion(0).front(); |
| 188 | if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) { |
| 189 | return emitSilenceableError() << "not a passthrough" ; |
| 190 | } |
| 191 | return DiagnosedSilenceableFailure::success(); |
| 192 | } |
| 193 | if (getElementwise()) { |
| 194 | if (!isElementwise(linalgOp)) |
| 195 | return emitSilenceableError() << "not elementwise" ; |
| 196 | return DiagnosedSilenceableFailure::success(); |
| 197 | } |
| 198 | if (std::optional<ArrayAttr> contractionOps = getContraction()) { |
| 199 | Block &body = linalgOp->getRegion(0).front(); |
| 200 | std::string message; |
| 201 | llvm::raw_string_ostream os(message); |
| 202 | bool result = linalg::detail::isContractionBody( |
| 203 | body, |
| 204 | [&](Operation *elem, Operation *red) { |
| 205 | return elem->getName().getStringRef() == |
| 206 | cast<StringAttr>((*contractionOps)[0]).getValue() && |
| 207 | red->getName().getStringRef() == |
| 208 | cast<StringAttr>((*contractionOps)[1]).getValue(); |
| 209 | }, |
| 210 | os); |
| 211 | if (result) |
| 212 | return DiagnosedSilenceableFailure::success(); |
| 213 | return emitSilenceableError() << "contraction: " << message; |
| 214 | } |
| 215 | return emitDefiniteFailure() << "unknown body condition" ; |
| 216 | } |
| 217 | |
| 218 | LogicalResult transform::MatchStructuredBodyOp::verify() { |
| 219 | int64_t numOptions = getReductionPosition().has_value() + getPassthrough() + |
| 220 | getElementwise() + getContraction().has_value(); |
| 221 | |
| 222 | if (numOptions > 1) { |
| 223 | StringAttr attributeNames[] = { |
| 224 | getReductionPositionAttrName(), getPassthroughAttrName(), |
| 225 | getElementwiseAttrName(), getContractionAttrName()}; |
| 226 | return emitOpError() << "only one of {" << llvm::interleaved(attributeNames) |
| 227 | << "} is allowed" ; |
| 228 | } |
| 229 | |
| 230 | if (std::optional<ArrayAttr> contractionAttr = getContraction()) { |
| 231 | if (contractionAttr->size() != 2) { |
| 232 | return emitOpError() << "expects " << getContractionAttrName() |
| 233 | << " to contain two elements" ; |
| 234 | } |
| 235 | } |
| 236 | return success(); |
| 237 | } |
| 238 | |
| 239 | //===----------------------------------------------------------------------===// |
| 240 | // MatchStructuredClassifyContractionDimsOp |
| 241 | //===----------------------------------------------------------------------===// |
| 242 | |
| 243 | DiagnosedSilenceableFailure |
| 244 | transform::MatchStructuredClassifyContractionDimsOp::matchOperation( |
| 245 | Operation *current, transform::TransformResults &results, |
| 246 | transform::TransformState &state) { |
| 247 | FailureOr<linalg::ContractionDimensions> contractionDims = |
| 248 | linalg::inferContractionDims(cast<linalg::LinalgOp>(current)); |
| 249 | if (failed(contractionDims)) |
| 250 | return emitSilenceableError() << "could not infer contraction dimensions" ; |
| 251 | |
| 252 | MLIRContext *context = current->getContext(); |
| 253 | Builder builder(context); |
| 254 | auto makeI64Attrs = [&](ArrayRef<unsigned> values) { |
| 255 | return llvm::to_vector( |
| 256 | llvm::map_range(values, [&](unsigned value) -> Attribute { |
| 257 | return builder.getI64IntegerAttr(value); |
| 258 | })); |
| 259 | }; |
| 260 | results.setParams(cast<OpResult>(getBatch()), |
| 261 | makeI64Attrs(contractionDims->batch)); |
| 262 | results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m)); |
| 263 | results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n)); |
| 264 | results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k)); |
| 265 | return DiagnosedSilenceableFailure::success(); |
| 266 | } |
| 267 | |
| 268 | //===----------------------------------------------------------------------===// |
| 269 | // MatchStructuredClassifyConvolutionDimsOp |
| 270 | //===----------------------------------------------------------------------===// |
| 271 | |
| 272 | DiagnosedSilenceableFailure |
| 273 | transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation( |
| 274 | Operation *current, transform::TransformResults &results, |
| 275 | transform::TransformState &state) { |
| 276 | FailureOr<linalg::ConvolutionDimensions> convolutionDims = |
| 277 | linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current)); |
| 278 | if (failed(convolutionDims)) |
| 279 | return emitSilenceableError() << "could not infer convolution dimensions" ; |
| 280 | |
| 281 | MLIRContext *context = current->getContext(); |
| 282 | Builder builder(context); |
| 283 | auto makeI64Attrs = [&](ArrayRef<unsigned> values) { |
| 284 | return llvm::to_vector( |
| 285 | llvm::map_range(values, [&](unsigned value) -> Attribute { |
| 286 | return builder.getI64IntegerAttr(value); |
| 287 | })); |
| 288 | }; |
| 289 | results.setParams(cast<OpResult>(getBatch()), |
| 290 | makeI64Attrs(convolutionDims->batch)); |
| 291 | results.setParams(cast<OpResult>(getOutputImage()), |
| 292 | makeI64Attrs(convolutionDims->outputImage)); |
| 293 | results.setParams(cast<OpResult>(getOutputChannel()), |
| 294 | makeI64Attrs(convolutionDims->outputChannel)); |
| 295 | results.setParams(cast<OpResult>(getFilterLoop()), |
| 296 | makeI64Attrs(convolutionDims->filterLoop)); |
| 297 | results.setParams(cast<OpResult>(getInputChannel()), |
| 298 | makeI64Attrs(convolutionDims->inputChannel)); |
| 299 | results.setParams(cast<OpResult>(getDepth()), |
| 300 | makeI64Attrs(convolutionDims->depth)); |
| 301 | |
| 302 | auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) { |
| 303 | return llvm::to_vector( |
| 304 | llvm::map_range(values, [&](int64_t value) -> Attribute { |
| 305 | return builder.getI64IntegerAttr(value); |
| 306 | })); |
| 307 | }; |
| 308 | results.setParams(cast<OpResult>(getStrides()), |
| 309 | makeI64AttrsFromI64(convolutionDims->strides)); |
| 310 | results.setParams(cast<OpResult>(getDilations()), |
| 311 | makeI64AttrsFromI64(convolutionDims->dilations)); |
| 312 | return DiagnosedSilenceableFailure::success(); |
| 313 | } |
| 314 | |
| 315 | //===----------------------------------------------------------------------===// |
| 316 | // Utilities for structured match predicates. |
| 317 | //===----------------------------------------------------------------------===// |
| 318 | |
| 319 | /// Checks if all values from `list` are also contained in `reference`. Returns |
| 320 | /// a silenceable error with the given message at the given location when it is |
| 321 | /// not the case. The error message must contain the "{0}" placeholder that |
| 322 | /// will be substituted with the value from `list` that is not contained in |
| 323 | /// `reference`. |
| 324 | static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference, |
| 325 | ArrayRef<int64_t> list, |
| 326 | Location loc, |
| 327 | const char *message) { |
| 328 | for (int64_t value : list) { |
| 329 | if (llvm::any_of(Range&: reference, P: [&](unsigned ref) { |
| 330 | return static_cast<int64_t>(ref) == value; |
| 331 | })) { |
| 332 | continue; |
| 333 | } |
| 334 | return emitSilenceableFailure(loc) << llvm::formatv(Fmt: message, Vals&: value); |
| 335 | } |
| 336 | return DiagnosedSilenceableFailure::success(); |
| 337 | } |
| 338 | |
| 339 | //===----------------------------------------------------------------------===// |
| 340 | // MatchStructuredDimOp |
| 341 | //===----------------------------------------------------------------------===// |
| 342 | |
| 343 | DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation( |
| 344 | Operation *current, transform::TransformResults &results, |
| 345 | transform::TransformState &state) { |
| 346 | auto linalgOp = cast<linalg::LinalgOp>(current); |
| 347 | SmallVector<int64_t> dimensions; |
| 348 | DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions); |
| 349 | if (!diag.succeeded()) |
| 350 | return diag; |
| 351 | |
| 352 | // If asked to check for the kind of dimension, perform the check. |
| 353 | if (getParallel() || getReduction()) { |
| 354 | SmallVector<unsigned> reference; |
| 355 | if (getParallel()) |
| 356 | linalgOp.getParallelDims(reference); |
| 357 | else if (getReduction()) |
| 358 | linalgOp.getReductionDims(reference); |
| 359 | |
| 360 | DiagnosedSilenceableFailure diag = |
| 361 | containsAll(reference, dimensions, getLoc(), |
| 362 | getParallel() ? "expects dimension #{0} to be parallel" |
| 363 | : "expects dimension #{0} to be reduction" ); |
| 364 | if (!diag.succeeded()) |
| 365 | return diag; |
| 366 | } |
| 367 | |
| 368 | // If not capturing, we are done here. |
| 369 | if (!getResult()) |
| 370 | return diag; |
| 371 | |
| 372 | SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges(); |
| 373 | Builder builder(current); |
| 374 | SmallVector<Attribute> captured = llvm::to_vector( |
| 375 | llvm::map_range(dimensions, [&](int64_t dim) -> Attribute { |
| 376 | return builder.getI64IntegerAttr(ranges[dim]); |
| 377 | })); |
| 378 | results.setParams(cast<OpResult>(getResult()), captured); |
| 379 | return DiagnosedSilenceableFailure::success(); |
| 380 | } |
| 381 | |
| 382 | DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor( |
| 383 | linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) { |
| 384 | DiagnosedSilenceableFailure diag = |
| 385 | expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(), |
| 386 | getRawDimList(), op.getNumLoops(), dims); |
| 387 | if (diag.isSilenceableFailure()) { |
| 388 | diag.attachNote(op->getLoc()) |
| 389 | << "while considering dimensions of this payload operation" ; |
| 390 | } |
| 391 | return diag; |
| 392 | } |
| 393 | |
| 394 | LogicalResult transform::MatchStructuredDimOp::verify() { |
| 395 | if (getParallel() && getReduction()) { |
| 396 | return emitOpError() << "cannot request the same dimension to be both " |
| 397 | "parallel and reduction" ; |
| 398 | } |
| 399 | return verifyTransformMatchDimsOp(getOperation(), getRawDimList(), |
| 400 | getIsInverted(), getIsAll()); |
| 401 | } |
| 402 | |
| 403 | //===----------------------------------------------------------------------===// |
| 404 | // MatchStructuredElementalBitwidthOp |
| 405 | //===----------------------------------------------------------------------===// |
| 406 | |
| 407 | DiagnosedSilenceableFailure |
| 408 | transform::MatchStructuredElementalBitwidthOp::matchValue( |
| 409 | Value current, transform::TransformResults &results, |
| 410 | transform::TransformState &state) { |
| 411 | auto setupResult = [&](int64_t bitwidth) { |
| 412 | Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth); |
| 413 | results.setParams(cast<OpResult>(getResult()), {attr}); |
| 414 | return DiagnosedSilenceableFailure::success(); |
| 415 | }; |
| 416 | |
| 417 | Type type = current.getType(); |
| 418 | if (type.isIntOrFloat()) |
| 419 | return setupResult(type.getIntOrFloatBitWidth()); |
| 420 | |
| 421 | if (auto shapedType = dyn_cast<ShapedType>(type)) { |
| 422 | if (shapedType.getElementType().isIntOrFloat()) |
| 423 | return setupResult(shapedType.getElementTypeBitWidth()); |
| 424 | } |
| 425 | return emitSilenceableError() |
| 426 | << "unsupported type for bitwidth extraction: " << type; |
| 427 | } |
| 428 | |
| 429 | //===----------------------------------------------------------------------===// |
| 430 | // MatchStructuredInputOp |
| 431 | //===----------------------------------------------------------------------===// |
| 432 | |
| 433 | DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation( |
| 434 | Operation *current, transform::TransformResults &results, |
| 435 | transform::TransformState &state) { |
| 436 | auto linalgOp = cast<linalg::LinalgOp>(current); |
| 437 | SmallVector<int64_t> positions; |
| 438 | DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions); |
| 439 | if (!diag.succeeded()) |
| 440 | return diag; |
| 441 | |
| 442 | SmallVector<MappedValue> operandMapping; |
| 443 | operandMapping.reserve(positions.size()); |
| 444 | for (int64_t position : positions) { |
| 445 | AffineMap indexingMap = |
| 446 | linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position)); |
| 447 | if (getPermutation() && !indexingMap.isPermutation()) { |
| 448 | return emitSilenceableError() << "the indexing map for input #" |
| 449 | << position << " is not a permutation" ; |
| 450 | } |
| 451 | if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) { |
| 452 | return emitSilenceableError() |
| 453 | << "the indexing map for input #" << position |
| 454 | << " is not a projected permutation" ; |
| 455 | } |
| 456 | |
| 457 | // If capture not requested, skip it. |
| 458 | if (!getResult()) |
| 459 | continue; |
| 460 | |
| 461 | if (isa<AffineMapParamType>(getResult().getType())) { |
| 462 | operandMapping.emplace_back(AffineMapAttr::get(indexingMap)); |
| 463 | continue; |
| 464 | } |
| 465 | |
| 466 | Value operand = linalgOp.getDpsInputOperand(position)->get(); |
| 467 | if (isa<TransformValueHandleTypeInterface>(getResult().getType())) { |
| 468 | operandMapping.emplace_back(operand); |
| 469 | continue; |
| 470 | } |
| 471 | |
| 472 | Operation *operandProducer = operand.getDefiningOp(); |
| 473 | if (!operandProducer) { |
| 474 | return emitSilenceableError() |
| 475 | << "input #" << position << " is not produced by an operation" ; |
| 476 | } |
| 477 | operandMapping.emplace_back(operandProducer); |
| 478 | } |
| 479 | if (getResult()) |
| 480 | results.setMappedValues(cast<OpResult>(getResult()), operandMapping); |
| 481 | return DiagnosedSilenceableFailure::success(); |
| 482 | } |
| 483 | |
| 484 | DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor( |
| 485 | linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) { |
| 486 | DiagnosedSilenceableFailure diag = expandTargetSpecification( |
| 487 | getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), |
| 488 | op.getNumDpsInputs(), positions); |
| 489 | if (diag.isSilenceableFailure()) { |
| 490 | diag.attachNote(op->getLoc()) |
| 491 | << "while considering DPS inputs of this payload operation" ; |
| 492 | } |
| 493 | return diag; |
| 494 | } |
| 495 | |
| 496 | /// Verifies a matcher op for structured input or output, specifically the |
| 497 | /// attributes specifying the operand positions. |
| 498 | template <typename OpTy> |
| 499 | LogicalResult verifyStructuredOperandOp(OpTy op) { |
| 500 | if (op.getPermutation() && op.getProjectedPermutation()) { |
| 501 | return op.emitOpError() |
| 502 | << op.getPermutationAttrName() << " and " |
| 503 | << op.getProjectedPermutationAttrName() << " are mutually exclusive" ; |
| 504 | } |
| 505 | if (op.getRawPositionList().size() > 1 && op.getResult()) { |
| 506 | return op.emitOpError() |
| 507 | << "cannot bind multiple inputs/inits to the same value" ; |
| 508 | } |
| 509 | |
| 510 | return success(); |
| 511 | } |
| 512 | |
| 513 | LogicalResult transform::MatchStructuredInputOp::verify() { |
| 514 | if (failed(verifyStructuredOperandOp(*this))) |
| 515 | return failure(); |
| 516 | return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), |
| 517 | getIsInverted(), getIsAll()); |
| 518 | } |
| 519 | |
| 520 | //===----------------------------------------------------------------------===// |
| 521 | // MatchStructuredInitOp |
| 522 | //===----------------------------------------------------------------------===// |
| 523 | |
| 524 | DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation( |
| 525 | Operation *current, transform::TransformResults &results, |
| 526 | transform::TransformState &state) { |
| 527 | auto linalgOp = cast<linalg::LinalgOp>(current); |
| 528 | SmallVector<int64_t> positions; |
| 529 | DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions); |
| 530 | if (!diag.succeeded()) |
| 531 | return diag; |
| 532 | |
| 533 | SmallVector<MappedValue> operandMapping; |
| 534 | operandMapping.reserve(positions.size()); |
| 535 | for (int64_t position : positions) { |
| 536 | AffineMap indexingMap = |
| 537 | linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position)); |
| 538 | if (getPermutation() && !indexingMap.isPermutation()) { |
| 539 | return emitSilenceableError() << "the indexing map for output(init) #" |
| 540 | << position << " is not a permutation" ; |
| 541 | } |
| 542 | if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) { |
| 543 | return emitSilenceableError() << "the indexing map for output(init) #" |
| 544 | << position << " is not a permutation" ; |
| 545 | } |
| 546 | |
| 547 | // If capture not requested, skip it. |
| 548 | if (!getResult()) |
| 549 | continue; |
| 550 | |
| 551 | if (isa<AffineMapParamType>(getResult().getType())) { |
| 552 | operandMapping.emplace_back(AffineMapAttr::get(indexingMap)); |
| 553 | continue; |
| 554 | } |
| 555 | |
| 556 | Value operand = linalgOp.getDpsInitOperand(position)->get(); |
| 557 | if (isa<TransformValueHandleTypeInterface>(getResult().getType())) { |
| 558 | operandMapping.emplace_back(operand); |
| 559 | continue; |
| 560 | } |
| 561 | |
| 562 | Operation *operandProducer = operand.getDefiningOp(); |
| 563 | if (!operandProducer) { |
| 564 | return emitSilenceableError() << "output(init) #" << position |
| 565 | << " is not produced by an operation" ; |
| 566 | } |
| 567 | operandMapping.emplace_back(operandProducer); |
| 568 | } |
| 569 | if (getResult()) |
| 570 | results.setMappedValues(cast<OpResult>(getResult()), operandMapping); |
| 571 | return DiagnosedSilenceableFailure::success(); |
| 572 | } |
| 573 | |
| 574 | DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor( |
| 575 | linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) { |
| 576 | DiagnosedSilenceableFailure diag = expandTargetSpecification( |
| 577 | getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), |
| 578 | op.getNumDpsInits(), positions); |
| 579 | if (diag.isSilenceableFailure()) { |
| 580 | diag.attachNote(op->getLoc()) |
| 581 | << "while considering DPS inits (outputs) of this payload operation" ; |
| 582 | } |
| 583 | return diag; |
| 584 | } |
| 585 | |
| 586 | LogicalResult transform::MatchStructuredInitOp::verify() { |
| 587 | if (failed(verifyStructuredOperandOp(*this))) |
| 588 | return failure(); |
| 589 | return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), |
| 590 | getIsInverted(), getIsAll()); |
| 591 | } |
| 592 | |
| 593 | //===----------------------------------------------------------------------===// |
| 594 | // MatchStructuredNumInputsOp |
| 595 | //===----------------------------------------------------------------------===// |
| 596 | |
| 597 | DiagnosedSilenceableFailure |
| 598 | transform::MatchStructuredNumInputsOp::matchOperation( |
| 599 | Operation *current, transform::TransformResults &results, |
| 600 | transform::TransformState &state) { |
| 601 | auto linalgOp = cast<linalg::LinalgOp>(current); |
| 602 | Attribute attr = |
| 603 | Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs()); |
| 604 | results.setParams(cast<OpResult>(getResult()), {attr}); |
| 605 | return DiagnosedSilenceableFailure::success(); |
| 606 | } |
| 607 | |
| 608 | //===----------------------------------------------------------------------===// |
| 609 | // MatchStructuredNumInitsOp |
| 610 | //===----------------------------------------------------------------------===// |
| 611 | |
| 612 | DiagnosedSilenceableFailure |
| 613 | transform::MatchStructuredNumInitsOp::matchOperation( |
| 614 | Operation *current, transform::TransformResults &results, |
| 615 | transform::TransformState &state) { |
| 616 | auto linalgOp = cast<linalg::LinalgOp>(current); |
| 617 | Attribute attr = |
| 618 | Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits()); |
| 619 | results.setParams(cast<OpResult>(getResult()), {attr}); |
| 620 | return DiagnosedSilenceableFailure::success(); |
| 621 | } |
| 622 | |
| 623 | //===----------------------------------------------------------------------===// |
| 624 | // MatchStructuredRankOp |
| 625 | //===----------------------------------------------------------------------===// |
| 626 | |
| 627 | DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation( |
| 628 | Operation *current, transform::TransformResults &results, |
| 629 | transform::TransformState &state) { |
| 630 | auto linalgOp = cast<linalg::LinalgOp>(current); |
| 631 | int64_t numLoops = linalgOp.getNumLoops(); |
| 632 | Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops); |
| 633 | results.setParams(cast<OpResult>(getRank()), {attr}); |
| 634 | return DiagnosedSilenceableFailure::success(); |
| 635 | } |
| 636 | |
| 637 | //===----------------------------------------------------------------------===// |
| 638 | // MatchStructuredResultOp |
| 639 | //===----------------------------------------------------------------------===// |
| 640 | |
| 641 | DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation( |
| 642 | Operation *op, transform::TransformResults &results, |
| 643 | transform::TransformState &state) { |
| 644 | auto linalgOp = cast<linalg::LinalgOp>(op); |
| 645 | int64_t position; |
| 646 | DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position); |
| 647 | if (!diag.succeeded()) |
| 648 | return diag; |
| 649 | |
| 650 | Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position)); |
| 651 | if (isa<TransformValueHandleTypeInterface>(getResult().getType())) { |
| 652 | results.setValues(cast<OpResult>(getResult()), {result}); |
| 653 | return DiagnosedSilenceableFailure::success(); |
| 654 | } |
| 655 | |
| 656 | if (result.getUsers().empty()) { |
| 657 | return emitSilenceableError() |
| 658 | << "no users of the result #" << getPosition(); |
| 659 | } |
| 660 | Operation *firstUser = *result.getUsers().begin(); |
| 661 | if (getAny()) { |
| 662 | results.set(cast<OpResult>(getResult()), {firstUser}); |
| 663 | return DiagnosedSilenceableFailure::success(); |
| 664 | } |
| 665 | if (getSingle()) { |
| 666 | if (!llvm::hasSingleElement(result.getUsers())) { |
| 667 | return emitSilenceableError() |
| 668 | << "more than one result user with single user requested" ; |
| 669 | } |
| 670 | results.set(cast<OpResult>(getResult()), {firstUser}); |
| 671 | return DiagnosedSilenceableFailure::success(); |
| 672 | } |
| 673 | |
| 674 | return emitDefiniteFailure() << "unknown sub-predicate" ; |
| 675 | } |
| 676 | |
| 677 | DiagnosedSilenceableFailure |
| 678 | transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op, |
| 679 | int64_t &position) { |
| 680 | auto rawPosition = static_cast<int64_t>(getPosition()); |
| 681 | position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition; |
| 682 | if (position >= op.getNumDpsInits() || position < 0) { |
| 683 | return emitSilenceableError() |
| 684 | << "position " << rawPosition |
| 685 | << " overflows the number of results(ints) of the payload operation" ; |
| 686 | } |
| 687 | return DiagnosedSilenceableFailure::success(); |
| 688 | } |
| 689 | |
| 690 | LogicalResult transform::MatchStructuredResultOp::verify() { |
| 691 | if ((getAny() || getSingle()) ^ |
| 692 | isa<TransformHandleTypeInterface>(getResult().getType())) { |
| 693 | return emitOpError() << "expects either the any/single keyword or the type " |
| 694 | "value handle result type" ; |
| 695 | } |
| 696 | if (getAny() && getSingle()) { |
| 697 | return emitOpError() << "'any' and 'single' are mutually exclusive" ; |
| 698 | } |
| 699 | return success(); |
| 700 | } |
| 701 | |
| 702 | //===----------------------------------------------------------------------===// |
| 703 | // MatchStructuredYieldOp |
| 704 | //===----------------------------------------------------------------------===// |
| 705 | |
| 706 | void transform::MatchStructuredYieldOp::getEffects( |
| 707 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 708 | onlyReadsHandle(getHandlesMutable(), effects); |
| 709 | onlyReadsPayload(effects); |
| 710 | } |
| 711 | |
| 712 | void transform::MatchStructuredYieldOp::build(OpBuilder &builder, |
| 713 | OperationState &state) { |
| 714 | build(builder, state, ValueRange()); |
| 715 | } |
| 716 | |
| 717 | #define GET_OP_CLASSES |
| 718 | #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc" |
| 719 | |