//===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
using namespace mlir::transform::gpu;

#define DEBUG_TYPE "gpu-transforms"
#define DEBUG_TYPE_ALIAS "gpu-transforms-alias"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  // Used in populateGpuToNVVMConversionPatterns, so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
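  // For example (illustrative), after conversion a workgroup-space memref such
  // as memref<32xf32, #gpu.address_space<workgroup>> is carried by an LLVM
  // pointer in addrspace(3), while private allocations become plain allocas in
  // the default address space.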
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](AddressSpace space) -> unsigned {
        switch (space) {
        case AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Used in GPUToNVVM/WmmaOpsToNvvm.cpp so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  llvmTypeConverter.addConversion(
      [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
  // Set higher benefit, so patterns will run before generic LLVM lowering.
  populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns,
                                      getBenefit());
}

LogicalResult
transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    populatePatterns(TypeConverter &typeConverter,
                     RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns);
}

LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuRewritePatterns(patterns);
}

void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// ApplyUnrollVectorsSubgroupMmaOp
//===----------------------------------------------------------------------===//

/// Pick an unrolling order that allows tensor core operations to reuse the
/// LHS register.
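/// For example (illustrative), a plain matmul contraction with iterator types
/// [parallel m, parallel n, reduction k] and LHS indexing map (m, k) yields
/// the order [k, m, n]: reduction dimensions first, then the parallel
/// dimensions used by the LHS, then the remaining parallel dimensions.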
static std::optional<SmallVector<int64_t>>
gpuMmaUnrollOrder(vector::ContractionOp contract) {
  SmallVector<int64_t> order;
  // First make reduction the outer dimensions.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isReductionIterator(iter)) {
      order.push_back(index);
    }
  }

  llvm::SmallDenseSet<int64_t> dims;
  for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
    dims.insert(cast<AffineDimExpr>(expr).getPosition());
  }
  // Then the parallel dimensions that appear in the LHS map, since we want to
  // reuse the LHS.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && dims.count(index)) {
      order.push_back(index);
    }
  }
  // Then the remaining parallel loops.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && !dims.count(index)) {
      order.push_back(index);
    }
  }
  return order;
}

/// Returns the target vector size for the target operation based on the native
/// vector size specified with `m`, `n`, and `k`.
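/// For instance (illustrative), with m = 16, n = 16, k = 8, a rank-3
/// contraction unrolls to the native size {16, 16, 8}, while a rank-4
/// contraction with one extra parallel dimension unrolls to {1, 16, 16, 8}.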
static std::optional<SmallVector<int64_t>>
getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
    int64_t contractRank = contract.getIteratorTypes().size();
    if (contractRank < 3)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(contractRank - 3, 1);
    nativeSize.append({m, n, k});
    return nativeSize;
  }
  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    int64_t writeRank = writeOp.getVectorType().getRank();
    if (writeRank < 2)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(writeRank - 2, 1);
    nativeSize.append({m, n});
    return nativeSize;
  }
  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
    // Transfer read ops may need different shapes based on how they are being
    // used. For simplicity just match the shape used by the extract strided
    // slice op.
    VectorType sliceType;
    for (Operation *users : op->getUsers()) {
      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
      if (!extract)
        return std::nullopt;
      auto vecType = cast<VectorType>(extract.getResult().getType());
      if (sliceType && sliceType != vecType)
        return std::nullopt;
      sliceType = vecType;
    }
    return llvm::to_vector(sliceType.getShape());
  }
  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
    if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
      // TODO: The condition for unrolling elementwise should be restricted
      // only to operations that need unrolling (connected to the contract).
      if (vecType.getRank() < 2)
        return std::nullopt;

      // First check whether there is a slice to infer the shape from. This is
      // required for cases where the accumulator type differs from the input
      // types, in which case we will see an `arith.ext_` between the contract
      // and transfer_read which needs to be unrolled.
      VectorType sliceType;
      for (Operation *users : op->getUsers()) {
        auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
        if (!extract)
          return std::nullopt;
        auto vecType = cast<VectorType>(extract.getResult().getType());
        if (sliceType && sliceType != vecType)
          return std::nullopt;
        sliceType = vecType;
      }
      if (sliceType)
        return llvm::to_vector(sliceType.getShape());

      // Else unroll for trailing elementwise.
      SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
      // Map elementwise ops to the output shape.
      nativeSize.append({m, n});
      return nativeSize;
    }
  }
  return std::nullopt;
}

void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
    RewritePatternSet &patterns) {
  auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
    auto contract = dyn_cast<vector::ContractionOp>(op);
    if (!contract)
      return std::nullopt;
    return gpuMmaUnrollOrder(contract);
  };

  int64_t m = getM();
  int64_t n = getN();
  int64_t k = getK();
  auto nativeShapeFn =
      [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> {
    return getSubgroupMmaNativeVectorSize(op, m, n, k);
  };
  vector::populateVectorUnrollPatterns(
      patterns, vector::UnrollVectorOptions()
                    .setNativeShapeFn(nativeShapeFn)
                    .setUnrollTraversalOrderFn(unrollOrder));
}

//===----------------------------------------------------------------------===//
// EliminateBarriersOp
//===----------------------------------------------------------------------===//

void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuEliminateBarriersPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Block and thread mapping utilities.
//===----------------------------------------------------------------------===//

namespace {
/// Local types used for mapping verification.
struct MappingKind {};
struct BlockMappingKind : MappingKind {};
struct ThreadMappingKind : MappingKind {};
} // namespace

static DiagnosedSilenceableFailure
definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
                      Operation *target, const Twine &message) {
  if (transformOp.has_value())
    return transformOp->emitDefiniteFailure() << message;
  return emitDefiniteFailure(target, message);
}

/// Check if the given mapping attributes are one of the desired attributes.
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  if (!forallOp.getMapping().has_value()) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute");
  }

  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
                                      llvm::IsaPred<GPUBlockMappingAttr>);
  bool hasWarpgroupMapping = llvm::any_of(
      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPUWarpMappingAttr>);
  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
                                       llvm::IsaPred<GPUThreadMappingAttr>);
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }
  if (std::is_same<MappingKindType, BlockMappingKind>::value &&
      !hasBlockMapping) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "scf.forall op requires a mapping attribute of kind 'block'");
  }
  if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
      !hasThreadMapping && !hasWarpMapping && !hasWarpgroupMapping) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute "
                                 "of kind 'thread' or 'warp'");
  }

  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicate attribute, cannot map different loops "
          "to the same mapping id");
    }
    seen.insert(map);
  }

  auto isLinear = [](Attribute a) {
    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
  };
  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix linear and non-linear mapping modes");
  }

  return DiagnosedSilenceableFailure::success();
}

template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  // Check the types of the mapping attributes match.
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  // Perform other non-types verifications.
  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
                              forallOp.getMapping()->getValue().front())
                              .isLinearMapping();
  // TODO: This would be more natural with support for Optional<EnumParameter>
  // in GPUDeviceMappingAttr.
  int64_t maxNumMappingsSupported =
      useLinearMapping ? (getMaxEnumValForMappingId() -
                          static_cast<uint64_t>(MappingId::DimZ))
                       : 3;
  if (forallOp.getRank() > maxNumMappingsSupported) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
  }
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  return DiagnosedSilenceableFailure::success();
}

/// Struct to return the result of the rewrite of a forall operation.
struct ForallRewriteResult {
  SmallVector<int64_t> mappingSizes;
  SmallVector<Value> mappingIds;
};

/// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR.
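/// For example, if the available mapping sizes are {4, 1, 1}, all uses of the
/// y and z id ops (whose extent is 1) are replaced by `replacement`, which is
/// typically a constant 0.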
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
                            OperationOrBlock *parent, Value replacement,
                            ArrayRef<int64_t> availableMappingSizes) {
  parent->walk([&](OpTy idOp) {
    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
  });
}

static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
    ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
  LDBG("--start rewriteOneForallCommonImpl");

  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
         "requires statically sized, normalized forall op");
  SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
  SetVector<Attribute> forallMappingAttrs;
  forallMappingAttrs.insert_range(forallOp.getMapping()->getValue());
  auto comparator = [](Attribute a, Attribute b) -> bool {
    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
           cast<DeviceMappingAttrInterface>(b).getMappingId();
  };

  // Step 1.b. In the linear case, compute the max mapping to avoid needlessly
  // mapping all dimensions. In the 3-D mapping case we need to map all
  // dimensions.
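  // For example (illustrative), a 1-D forall mapped to [#gpu.block<x>] gets
  // #gpu.block<y> and #gpu.block<z> appended below with size 1, whereas with a
  // linear mapping only ids up to the maximum used one are materialized.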
  DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>(
      *llvm::max_element(forallMappingAttrs, comparator));
  DeviceMappingAttrInterface maxLinearMapping;
  if (maxMapping.isLinearMapping())
    maxLinearMapping = maxMapping;
  for (auto attr : gpuIdBuilder.mappingAttributes) {
    // If attr overflows, just skip.
    if (maxLinearMapping && comparator(maxLinearMapping, attr))
      continue;
    // Try to insert. If element was already present, just continue.
    if (!forallMappingAttrs.insert(attr))
      continue;
    // Otherwise, we have a new insertion without a size -> use size 1.
    tmpMappingSizes.push_back(1);
  }
  LDBG("----tmpMappingSizes extracted from scf.forall op: "
       << llvm::interleaved(tmpMappingSizes));

  // Step 2. Sort the values by the corresponding DeviceMappingAttrInterface.
  SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
      forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
  LDBG("----forallMappingSizes: " << llvm::interleaved(forallMappingSizes));
  LDBG("----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs));

  // Step 3. Generate the mappingIdOps using the provided generator.
  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forallOp);
  SmallVector<int64_t> originalBasis(availableMappingSizes);
  bool originalBasisWasProvided = !originalBasis.empty();
  if (!originalBasisWasProvided) {
    originalBasis = forallMappingSizes;
    while (originalBasis.size() < 3)
      originalBasis.push_back(1);
  }

  IdBuilderResult builderResult =
      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);

  // Step 4. Map the induction variables to the mappingIdOps, this may involve
  // a permutation.
  SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
  IRMapping bvm;
  for (auto [iv, dim] : llvm::zip_equal(
           forallOp.getInductionVars(),
           forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
    Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
    bvm.map(iv, peIdOp);
  }

  // Step 5. If the originalBasis is already known, create conditionals to
  // predicate the region. Otherwise, the current forall determines the
  // originalBasis and no predication occurs.
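  // For example (illustrative), mapping 5 loop iterations onto a basis of 8
  // threads along x produces a guard of the form `arith.cmpi ult, %tid_x, %c5`
  // and the loop body ends up inside an scf.if on that predicate.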
  Value predicate;
  if (originalBasisWasProvided) {
    SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes;
    SmallVector<int64_t> availableMappingSizes =
        builderResult.availableMappingSizes;
    SmallVector<Value> activeIdOps = builderResult.activeIdOps;
    LDBG("----activeMappingSizes: " << llvm::interleaved(activeMappingSizes));
    LDBG("----availableMappingSizes: "
         << llvm::interleaved(availableMappingSizes));
    LDBG("----activeIdOps: " << llvm::interleaved(activeIdOps));
    for (auto [activeId, activeMappingSize, availableMappingSize] :
         llvm::zip_equal(activeIdOps, activeMappingSizes,
                         availableMappingSizes)) {
      if (activeMappingSize > availableMappingSize) {
        return definiteFailureHelper(
            transformOp, forallOp,
            "Trying to map to fewer GPU threads than loop iterations but "
            "overprovisioning is not yet supported. "
            "Try additional tiling before mapping or map to more threads.");
      }
      if (activeMappingSize == availableMappingSize)
        continue;
      Value idx =
          rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
      Value tmpPredicate = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, activeId, idx);
      LDBG("----predicate: " << tmpPredicate);
      predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                             tmpPredicate)
                            : tmpPredicate;
    }
  }

  // Step 6. Move the body of forallOp.
  // Erase the terminator first, it will not be used.
  rewriter.eraseOp(forallOp.getTerminator());
  Block *targetBlock;
  Block::iterator insertionPoint;
  if (predicate) {
    // Step 6.a. If predicated, move at the beginning.
    auto ifOp = rewriter.create<scf::IfOp>(loc, predicate,
                                           /*withElseRegion=*/false);
    targetBlock = ifOp.thenBlock();
    insertionPoint = ifOp.thenBlock()->begin();
  } else {
    // Step 6.b. Otherwise, move inline just at the rewriter insertion
    // point.
    targetBlock = forallOp->getBlock();
    insertionPoint = rewriter.getInsertionPoint();
  }
  Block &sourceBlock = forallOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // Step 7. RAUW indices.
  for (Value loopIndex : forallOp.getInductionVars()) {
    Value threadIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
  }

  // Step 8. Erase old op.
  rewriter.eraseOp(forallOp);

  LDBG("----result forallMappingSizes: "
       << llvm::interleaved(forallMappingSizes));
  LDBG("----result mappingIdOps: " << llvm::interleaved(mappingIdOps));

  result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MapForallToBlocks
//===----------------------------------------------------------------------===//

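/// Illustrative sketch of the rewrite performed below (not the exact printed
/// IR): an
///   scf.forall (%i, %j) in (8, 4) {...}
///     {mapping = [#gpu.block<x>, #gpu.block<y>]}
/// is erased, its body is inlined with %i and %j replaced by gpu.block_id x
/// and gpu.block_id y, and `gridDims` is populated as {8, 4, 1}.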
DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
    RewriterBase &rewriter, TransformOpInterface transformOp,
    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
    const GpuIdBuilder &gpuIdBuilder) {
  LDBG("Start mapForallToBlocksImpl");

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<BlockMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  Block *parentBlock = forallOp->getBlock();
  Value zero;
  {
    // Create an early zero index value for replacements and immediately reset
    // the insertion point.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(parentBlock);
    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  }

  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp,
      /*availableMappingSizes=*/gridDims, rewriteResult, gpuIdBuilder);

  // Return if anything goes wrong, use silenceable failure as a match
  // failure.
  if (!diag.succeeded())
    return diag;

  // If gridDims was not provided already, set it from the return.
  if (gridDims.empty()) {
    gridDims = rewriteResult.mappingSizes;
    while (gridDims.size() < 3)
      gridDims.push_back(1);
  }
  assert(gridDims.size() == 3 && "Need 3-D gridDims");

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
                                          rewriteResult.mappingSizes);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
mlir::transform::gpu::findTopLevelForallOp(Operation *target,
                                           scf::ForallOp &topLevelForallOp,
                                           TransformOpInterface transformOp) {
  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
    if (forallOp->getParentOfType<scf::ForallOp>())
      return WalkResult::advance();
    if (topLevelForallOp)
      // TODO: Handle multiple forall if they are independent.
      return WalkResult::interrupt();
    topLevelForallOp = forallOp;
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted() || !topLevelForallOp)
    return transformOp.emitSilenceableError()
           << "could not find a unique topLevel scf.forall";
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag = mlir::transform::gpu::findTopLevelForallOp(
      target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  assert(topLevelForallOp && "expect an scf.forall");

  SmallVector<int64_t> gridDims{getGridDims()};
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure(
        "transform requires size-3 mapping");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(topLevelForallOp);

  // Generate the gpu launch here and move the forall inside.
  if (getGenerateGpuLaunch()) {
    DiagnosedSilenceableFailure diag =
        createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded())
      return diag;

    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForallOp = rewriter.clone(*topLevelForallOp);
    rewriter.eraseOp(topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(newForallOp);
  }

  // The BlockIdBuilder adapts to whatever is thrown at it.
  bool useLinearMapping = false;
  if (topLevelForallOp.getMapping()) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(
        topLevelForallOp.getMapping()->getValue().front());
    useLinearMapping = mappingAttr.isLinearMapping();
  }
  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);

  diag = mlir::transform::gpu::mapForallToBlocksImpl(
      rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Set the GPU launch configuration for the grid dims late, this is
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        cast<TransformOpInterface>(getOperation()), gridDims[0],
                        gridDims[1], gridDims[2]);

  results.push_back(gpuLaunch);
  return diag;
}

LogicalResult transform::MapForallToBlocks::verify() {
  if (!getGridDims().empty() && getGridDims().size() != 3) {
    return emitOpError() << "transform requires empty or size-3 grid_dims";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MapNestedForallToThreads
//===----------------------------------------------------------------------===//

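/// Checks that the requested block or grid sizes can accommodate the parallel
/// iterations of `forallOp`. Illustrative example: with a warp mapping the
/// factor is the warp size (e.g. 32), so a forall of 2x2 iterations needs
/// 2 * 2 * 32 = 128 threads; block sizes {128, 1, 1} pass both checks, while
/// {64, 1, 1} fails the resource check since 128 > 64.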
static DiagnosedSilenceableFailure checkMappingSpec(
    std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
    ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
    int factor, bool useLinearMapping = false) {
  if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
            Twine(factor));
    return diag;
  }
  if (computeProduct(numParallelIterations) * factor >
      computeProduct(blockOrGridSizes)) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("the number of required parallel resources (blocks or "
              "threads) ") +
            Twine(computeProduct(numParallelIterations) * factor) +
            " overflows the number of available resources " +
            Twine(computeProduct(blockOrGridSizes)));
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}

static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                   scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                   int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
  auto mappingAttr = cast<DeviceMappingAttrInterface>(
      forallOp.getMapping()->getValue().front());
  bool useLinearMapping = mappingAttr.isLinearMapping();

  // Sanity checks that may result in runtime verification errors.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  int64_t factor = 1;
  if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
  } else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
    factor = warpSize;
  }
  DiagnosedSilenceableFailure diag =
      checkMappingSpec(transformOp, forallOp, numParallelIterations.value(),
                       blockSizes, factor, useLinearMapping);
  if (!diag.succeeded())
    return diag;

  // Start mapping.
  MLIRContext *ctx = forallOp.getContext();
  gpuIdBuilder =
      TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
          .Case([&](GPUWarpgroupMappingAttr) {
            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUWarpMappingAttr) {
            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUThreadMappingAttr) {
            return GpuThreadIdBuilder(ctx, useLinearMapping);
          })
          .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
            llvm_unreachable("unknown mapping attribute");
          });
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
    bool syncAfterDistribute) {

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  GpuIdBuilder gpuIdBuilder;
  {
    // Try to construct the id builder, if it fails, return.
    DiagnosedSilenceableFailure diag = getThreadIdBuilder(
        transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  // Insert after to allow for syncthreads after `forall` is erased.
  rewriter.setInsertionPointAfter(forallOp);
  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder);
  if (!diag.succeeded())
    return diag;
  // Add a syncthreads if needed. TODO: warpsync
  if (syncAfterDistribute)
    rewriter.create<BarrierOp>(loc);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
    bool syncAfterDistribute) {
  LDBG("Start mapNestedForallToThreadsImpl");
  if (blockDims.size() != 3) {
    return definiteFailureHelper(transformOp, target,
                                 "requires size-3 thread mapping");
  }

  // Create an early zero index value for replacements.
  Location loc = target->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
        rewriter, transformOp, forallOp, blockDims, warpSize,
        syncAfterDistribute);
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();
    if (diag.succeeded())
      return WalkResult::skip();
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return diag;

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
                                          blockDims);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  // Basic high-level verifications.
  if (!gpuLaunch)
    return emitSilenceableError() << "Given target is not a gpu.launch";

  // Mapping to block ids.
  SmallVector<int64_t> blockDims{getBlockDims()};
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                     blockDims[0], blockDims[1], blockDims[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
    return diag;
  }

  // Set the GPU launch configuration for the block dims early, this is not
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
                        std::nullopt, std::nullopt, blockDims[0], blockDims[1],
                        blockDims[2]);

  rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
  diag =
      mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
                                   getWarpSize(), getSyncAfterDistribute());

  results.push_back(gpuLaunch.getOperation());
  return diag;
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
/// Registers the GPU transform ops and declares the dialects that applying
/// them may generate.
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)

  GPUTransformDialectExtension() {
    declareGeneratedDialect<GPUDialect>();
    declareGeneratedDialect<amdgpu::AMDGPUDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<scf::SCFDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"

void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}