| 1 | //===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" |
| 10 | |
| 11 | #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" |
| 12 | #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" |
| 13 | #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" |
| 14 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
| 15 | #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" |
| 16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 17 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 18 | #include "mlir/Dialect/GPU/TransformOps/Utils.h" |
| 19 | #include "mlir/Dialect/GPU/Transforms/Passes.h" |
| 20 | #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
| 21 | #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" |
| 22 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 23 | #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" |
| 24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 25 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
| 26 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
| 27 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 28 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 29 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
| 30 | #include "mlir/IR/AffineExpr.h" |
| 31 | #include "mlir/IR/Builders.h" |
| 32 | #include "mlir/IR/BuiltinAttributes.h" |
| 33 | #include "mlir/IR/IRMapping.h" |
| 34 | #include "mlir/IR/MLIRContext.h" |
| 35 | #include "mlir/IR/OpDefinition.h" |
| 36 | #include "mlir/IR/Visitors.h" |
| 37 | #include "mlir/Support/LLVM.h" |
| 38 | #include "mlir/Transforms/DialectConversion.h" |
| 39 | #include "llvm/ADT/STLExtras.h" |
| 40 | #include "llvm/ADT/SmallVector.h" |
| 41 | #include "llvm/ADT/TypeSwitch.h" |
| 42 | #include "llvm/Support/Debug.h" |
| 43 | #include "llvm/Support/ErrorHandling.h" |
| 44 | #include "llvm/Support/InterleavedRange.h" |
| 45 | #include "llvm/Support/LogicalResult.h" |
| 46 | #include <type_traits> |
| 47 | |
| 48 | using namespace mlir; |
| 49 | using namespace mlir::gpu; |
| 50 | using namespace mlir::transform; |
| 51 | using namespace mlir::transform::gpu; |
| 52 | |
| 53 | #define DEBUG_TYPE "gpu-transforms" |
| 54 | #define DEBUG_TYPE_ALIAS "gpu-transforms-alias" |
| 55 | |
| 56 | #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") |
| 57 | #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") |
| 58 | #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ") |
| 59 | |
| 60 | //===----------------------------------------------------------------------===// |
| 61 | // Apply...ConversionPatternsOp |
| 62 | //===----------------------------------------------------------------------===// |
| 63 | |
| 64 | void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns( |
| 65 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
| 66 | auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); |
| 67 | // NVVM uses alloca in the default address space to represent private |
| 68 | // memory allocations, so drop private annotations. NVVM uses address |
| 69 | // space 3 for shared memory. NVVM uses the default address space to |
| 70 | // represent global memory. |
| 71 | // Used in populateGpuToNVVMConversionPatternsso attaching here for now. |
| 72 | // TODO: We should have a single to_nvvm_type_converter. |
| 73 | populateGpuMemorySpaceAttributeConversions( |
| 74 | typeConverter&: llvmTypeConverter, mapping: [](AddressSpace space) -> unsigned { |
| 75 | switch (space) { |
| 76 | case AddressSpace::Global: |
| 77 | return static_cast<unsigned>( |
| 78 | NVVM::NVVMMemorySpace::kGlobalMemorySpace); |
| 79 | case AddressSpace::Workgroup: |
| 80 | return static_cast<unsigned>( |
| 81 | NVVM::NVVMMemorySpace::kSharedMemorySpace); |
| 82 | case AddressSpace::Private: |
| 83 | return 0; |
| 84 | } |
| 85 | llvm_unreachable("unknown address space enum value" ); |
| 86 | return 0; |
| 87 | }); |
| 88 | // Used in GPUToNVVM/WmmaOpsToNvvm.cpp so attaching here for now. |
| 89 | // TODO: We should have a single to_nvvm_type_converter. |
| 90 | llvmTypeConverter.addConversion( |
| 91 | callback: [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); }); |
| 92 | // Set higher benefit, so patterns will run before generic LLVM lowering. |
| 93 | populateGpuToNVVMConversionPatterns(converter: llvmTypeConverter, patterns, |
| 94 | benefit: getBenefit()); |
| 95 | } |
| 96 | |
| 97 | LogicalResult |
| 98 | transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter( |
| 99 | transform::TypeConverterBuilderOpInterface builder) { |
| 100 | if (builder.getTypeConverterType() != "LLVMTypeConverter" ) |
| 101 | return emitOpError(message: "expected LLVMTypeConverter" ); |
| 102 | return success(); |
| 103 | } |
| 104 | |
| 105 | void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns( |
| 106 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
| 107 | auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); |
| 108 | populateGpuWMMAToNVVMConversionPatterns(converter: llvmTypeConverter, patterns); |
| 109 | } |
| 110 | |
| 111 | LogicalResult |
| 112 | transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter( |
| 113 | transform::TypeConverterBuilderOpInterface builder) { |
| 114 | if (builder.getTypeConverterType() != "LLVMTypeConverter" ) |
| 115 | return emitOpError(message: "expected LLVMTypeConverter" ); |
| 116 | return success(); |
| 117 | } |
| 118 | |
| 119 | void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp:: |
| 120 | populatePatterns(TypeConverter &typeConverter, |
| 121 | RewritePatternSet &patterns) { |
| 122 | auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); |
| 123 | populateGpuSubgroupReduceOpLoweringPattern(converter: llvmTypeConverter, patterns); |
| 124 | } |
| 125 | |
| 126 | LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp:: |
| 127 | verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) { |
| 128 | if (builder.getTypeConverterType() != "LLVMTypeConverter" ) |
| 129 | return emitOpError(message: "expected LLVMTypeConverter" ); |
| 130 | return success(); |
| 131 | } |
| 132 | |
| 133 | void transform::ApplyGPUToROCDLConversionPatternsOp::populatePatterns( |
| 134 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
| 135 | auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); |
| 136 | populateGpuMemorySpaceAttributeConversions( |
| 137 | typeConverter&: llvmTypeConverter, mapping: [](AddressSpace space) { |
| 138 | switch (space) { |
| 139 | case AddressSpace::Global: |
| 140 | return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace; |
| 141 | case AddressSpace::Workgroup: |
| 142 | return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace; |
| 143 | case AddressSpace::Private: |
| 144 | return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace; |
| 145 | } |
| 146 | llvm_unreachable("unknown address space enum value" ); |
| 147 | }); |
| 148 | FailureOr<amdgpu::Chipset> maybeChipset = |
| 149 | amdgpu::Chipset::parse(name: getChipset()); |
| 150 | assert(llvm::succeeded(maybeChipset) && "expected valid chipset" ); |
| 151 | populateGpuToROCDLConversionPatterns( |
| 152 | converter: llvmTypeConverter, patterns, runtime: mlir::gpu::amd::Runtime::HIP, chipset: *maybeChipset); |
| 153 | } |
| 154 | |
| 155 | LogicalResult |
| 156 | transform::ApplyGPUToROCDLConversionPatternsOp::verifyTypeConverter( |
| 157 | transform::TypeConverterBuilderOpInterface builder) { |
| 158 | FailureOr<amdgpu::Chipset> maybeChipset = |
| 159 | amdgpu::Chipset::parse(name: getChipset()); |
| 160 | if (failed(Result: maybeChipset)) { |
| 161 | return emitOpError(message: "Invalid chipset name: " + getChipset()); |
| 162 | } |
| 163 | if (builder.getTypeConverterType() != "LLVMTypeConverter" ) |
| 164 | return emitOpError(message: "expected LLVMTypeConverter" ); |
| 165 | return success(); |
| 166 | } |
| 167 | |
| 168 | //===----------------------------------------------------------------------===// |
| 169 | // Apply...PatternsOp |
| 170 | //===----------------------------------------------------------------------===// |
| 171 | |
/// Populate the generic GPU dialect rewrite patterns (e.g. all-reduce and
/// shuffle rewrites) into `patterns`.
void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuRewritePatterns(patterns);
}
| 175 | |
/// Populate patterns that promote gpu.shuffle ops to their AMDGPU
/// equivalents into `patterns`.
void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
}
| 180 | |
| 181 | //===----------------------------------------------------------------------===// |
| 182 | // ApplyUnrollVectorsSubgroupMmaOp |
| 183 | //===----------------------------------------------------------------------===// |
| 184 | |
| 185 | /// Pick an unrolling order that will allow tensorcore operation to reuse LHS |
| 186 | /// register. |
| 187 | static std::optional<SmallVector<int64_t>> |
| 188 | gpuMmaUnrollOrder(vector::ContractionOp contract) { |
| 189 | SmallVector<int64_t> order; |
| 190 | // First make reduction the outer dimensions. |
| 191 | for (auto [index, iter] : llvm::enumerate(First: contract.getIteratorTypes())) { |
| 192 | if (vector::isReductionIterator(attr: iter)) { |
| 193 | order.push_back(Elt: index); |
| 194 | } |
| 195 | } |
| 196 | |
| 197 | llvm::SmallDenseSet<int64_t> dims; |
| 198 | for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) { |
| 199 | dims.insert(V: cast<AffineDimExpr>(Val&: expr).getPosition()); |
| 200 | } |
| 201 | // Then parallel dimensions that are part of Lhs as we want to re-use Lhs. |
| 202 | for (auto [index, iter] : llvm::enumerate(First: contract.getIteratorTypes())) { |
| 203 | if (vector::isParallelIterator(attr: iter) && dims.count(V: index)) { |
| 204 | order.push_back(Elt: index); |
| 205 | } |
| 206 | } |
| 207 | // Then the remaining parallel loops. |
| 208 | for (auto [index, iter] : llvm::enumerate(First: contract.getIteratorTypes())) { |
| 209 | if (vector::isParallelIterator(attr: iter) && !dims.count(V: index)) { |
| 210 | order.push_back(Elt: index); |
| 211 | } |
| 212 | } |
| 213 | return order; |
| 214 | } |
| 215 | |
| 216 | /// Returns the target vector size for the target operation based on the native |
| 217 | /// vector size specified with `m`, `n`, and `k`. |
| 218 | static std::optional<SmallVector<int64_t>> |
| 219 | getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) { |
| 220 | if (auto contract = dyn_cast<vector::ContractionOp>(Val: op)) { |
| 221 | int64_t contractRank = contract.getIteratorTypes().size(); |
| 222 | if (contractRank < 3) |
| 223 | return std::nullopt; |
| 224 | SmallVector<int64_t> nativeSize(contractRank - 3, 1); |
| 225 | nativeSize.append(IL: {m, n, k}); |
| 226 | return nativeSize; |
| 227 | } |
| 228 | if (auto writeOp = dyn_cast<vector::TransferWriteOp>(Val: op)) { |
| 229 | int64_t writeRank = writeOp.getVectorType().getRank(); |
| 230 | if (writeRank < 2) |
| 231 | return std::nullopt; |
| 232 | SmallVector<int64_t> nativeSize(writeRank - 2, 1); |
| 233 | nativeSize.append(IL: {m, n}); |
| 234 | return nativeSize; |
| 235 | } |
| 236 | if (auto readOp = dyn_cast<vector::TransferReadOp>(Val: op)) { |
| 237 | // Transfer read ops may need different shapes based on how they are being |
| 238 | // used. For simplicity just match the shape used by the extract strided op. |
| 239 | VectorType sliceType; |
| 240 | for (Operation *users : op->getUsers()) { |
| 241 | auto = dyn_cast<vector::ExtractStridedSliceOp>(Val: users); |
| 242 | if (!extract) |
| 243 | return std::nullopt; |
| 244 | auto vecType = cast<VectorType>(Val: extract.getResult().getType()); |
| 245 | if (sliceType && sliceType != vecType) |
| 246 | return std::nullopt; |
| 247 | sliceType = vecType; |
| 248 | } |
| 249 | return llvm::to_vector(Range: sliceType.getShape()); |
| 250 | } |
| 251 | if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) { |
| 252 | if (auto vecType = dyn_cast<VectorType>(Val: op->getResultTypes()[0])) { |
| 253 | // TODO: The condition for unrolling elementwise should be restricted |
| 254 | // only to operations that need unrolling (connected to the contract). |
| 255 | if (vecType.getRank() < 2) |
| 256 | return std::nullopt; |
| 257 | |
| 258 | // First check whether there is a slice to infer the shape from. This is |
| 259 | // required for cases where the accumulator type differs from the input |
| 260 | // types, in which case we will see an `arith.ext_` between the contract |
| 261 | // and transfer_read which needs to be unrolled. |
| 262 | VectorType sliceType; |
| 263 | for (Operation *users : op->getUsers()) { |
| 264 | auto = dyn_cast<vector::ExtractStridedSliceOp>(Val: users); |
| 265 | if (!extract) |
| 266 | return std::nullopt; |
| 267 | auto vecType = cast<VectorType>(Val: extract.getResult().getType()); |
| 268 | if (sliceType && sliceType != vecType) |
| 269 | return std::nullopt; |
| 270 | sliceType = vecType; |
| 271 | } |
| 272 | if (sliceType) |
| 273 | return llvm::to_vector(Range: sliceType.getShape()); |
| 274 | |
| 275 | // Else unroll for trailing elementwise. |
| 276 | SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1); |
| 277 | // Map elementwise ops to the output shape. |
| 278 | nativeSize.append(IL: {m, n}); |
| 279 | return nativeSize; |
| 280 | } |
| 281 | } |
| 282 | return std::nullopt; |
| 283 | } |
| 284 | |
| 285 | void transform::ApplyUnrollVectorsSubgroupMmaOp::( |
| 286 | RewritePatternSet &patterns) { |
| 287 | auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> { |
| 288 | auto contract = dyn_cast<vector::ContractionOp>(Val: op); |
| 289 | if (!contract) |
| 290 | return std::nullopt; |
| 291 | return gpuMmaUnrollOrder(contract); |
| 292 | }; |
| 293 | |
| 294 | int64_t m = getM(); |
| 295 | int64_t n = getN(); |
| 296 | int64_t k = getK(); |
| 297 | auto nativeShapeFn = |
| 298 | [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> { |
| 299 | return getSubgroupMmaNativeVectorSize(op, m, n, k); |
| 300 | }; |
| 301 | vector::populateVectorUnrollPatterns( |
| 302 | patterns, options: vector::UnrollVectorOptions() |
| 303 | .setNativeShapeFn(nativeShapeFn) |
| 304 | .setUnrollTraversalOrderFn(unrollOrder)); |
| 305 | } |
| 306 | |
| 307 | //===----------------------------------------------------------------------===// |
| 308 | // EliminateBarriersOp |
| 309 | //===----------------------------------------------------------------------===// |
| 310 | |
/// Populate patterns that remove redundant gpu.barrier ops into `patterns`.
void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuEliminateBarriersPatterns(patterns);
}
| 314 | |
| 315 | //===----------------------------------------------------------------------===// |
| 316 | // Block and thread mapping utilities. |
| 317 | //===----------------------------------------------------------------------===// |
| 318 | |
namespace {
/// Local types used for mapping verification.
struct MappingKind {};
/// Tag selecting verification of block-level (grid) mappings.
struct BlockMappingKind : MappingKind {};
/// Tag selecting verification of thread/warp/warpgroup/lane mappings.
struct ThreadMappingKind : MappingKind {};
} // namespace
| 325 | |
| 326 | static DiagnosedSilenceableFailure |
| 327 | definiteFailureHelper(std::optional<TransformOpInterface> transformOp, |
| 328 | Operation *target, const Twine &message) { |
| 329 | if (transformOp.has_value()) |
| 330 | return transformOp->emitDefiniteFailure() << message; |
| 331 | return emitDefiniteFailure(op: target, message); |
| 332 | } |
| 333 | |
| 334 | /// Check if given mapping attributes are one of the desired attributes |
| 335 | template <typename MappingKindType> |
| 336 | static DiagnosedSilenceableFailure |
| 337 | checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp, |
| 338 | scf::ForallOp forallOp) { |
| 339 | if (!forallOp.getMapping().has_value()) { |
| 340 | return definiteFailureHelper(transformOp, target: forallOp, |
| 341 | message: "scf.forall op requires a mapping attribute" ); |
| 342 | } |
| 343 | |
| 344 | bool hasBlockMapping = llvm::any_of(Range: forallOp.getMapping().value(), |
| 345 | P: llvm::IsaPred<GPUBlockMappingAttr>); |
| 346 | bool hasWarpgroupMapping = llvm::any_of( |
| 347 | Range: forallOp.getMapping().value(), P: llvm::IsaPred<GPUWarpgroupMappingAttr>); |
| 348 | bool hasWarpMapping = llvm::any_of(Range: forallOp.getMapping().value(), |
| 349 | P: llvm::IsaPred<GPUWarpMappingAttr>); |
| 350 | bool hasThreadMapping = llvm::any_of(Range: forallOp.getMapping().value(), |
| 351 | P: llvm::IsaPred<GPUThreadMappingAttr>); |
| 352 | bool hasLaneMapping = llvm::any_of(Range: forallOp.getMapping().value(), |
| 353 | P: llvm::IsaPred<GPULaneMappingAttr>); |
| 354 | int64_t countMappingTypes = 0; |
| 355 | countMappingTypes += hasBlockMapping ? 1 : 0; |
| 356 | countMappingTypes += hasWarpgroupMapping ? 1 : 0; |
| 357 | countMappingTypes += hasWarpMapping ? 1 : 0; |
| 358 | countMappingTypes += hasThreadMapping ? 1 : 0; |
| 359 | countMappingTypes += hasLaneMapping ? 1 : 0; |
| 360 | if (countMappingTypes > 1) { |
| 361 | return definiteFailureHelper( |
| 362 | transformOp, target: forallOp, |
| 363 | message: "cannot mix different mapping types, use nesting" ); |
| 364 | } |
| 365 | if (std::is_same<MappingKindType, BlockMappingKind>::value && |
| 366 | !hasBlockMapping) { |
| 367 | return definiteFailureHelper( |
| 368 | transformOp, target: forallOp, |
| 369 | message: "scf.forall op requires a mapping attribute of kind 'block'" ); |
| 370 | } |
| 371 | if (std::is_same<MappingKindType, ThreadMappingKind>::value && |
| 372 | !hasLaneMapping && !hasThreadMapping && !hasWarpMapping && |
| 373 | !hasWarpgroupMapping) { |
| 374 | return definiteFailureHelper(transformOp, target: forallOp, |
| 375 | message: "scf.forall op requires a mapping attribute " |
| 376 | "of kind 'thread' or 'warp'" ); |
| 377 | } |
| 378 | |
| 379 | DenseSet<Attribute> seen; |
| 380 | for (Attribute map : forallOp.getMapping()->getValue()) { |
| 381 | if (seen.contains(V: map)) { |
| 382 | return definiteFailureHelper( |
| 383 | transformOp, target: forallOp, |
| 384 | message: "duplicate attribute, cannot map different loops " |
| 385 | "to the same mapping id" ); |
| 386 | } |
| 387 | seen.insert(V: map); |
| 388 | } |
| 389 | |
| 390 | auto isLinear = [](DeviceMappingAttrInterface attr) { |
| 391 | return attr.isLinearMapping(); |
| 392 | }; |
| 393 | if (llvm::any_of(forallOp.getDeviceMappingAttrs(), isLinear) && |
| 394 | !llvm::all_of(forallOp.getDeviceMappingAttrs(), isLinear)) { |
| 395 | return definiteFailureHelper( |
| 396 | transformOp, target: forallOp, |
| 397 | message: "cannot mix linear and non-linear mapping modes" ); |
| 398 | } |
| 399 | |
| 400 | FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr = |
| 401 | forallOp.getDeviceMaskingAttr(); |
| 402 | if (succeeded(Result: maybeMaskingAttr) && *maybeMaskingAttr && |
| 403 | !forallOp.usesLinearMapping()) { |
| 404 | return definiteFailureHelper( |
| 405 | transformOp, target: forallOp, |
| 406 | message: "device masking is only available in linear mapping mode" ); |
| 407 | } |
| 408 | |
| 409 | return DiagnosedSilenceableFailure::success(); |
| 410 | } |
| 411 | |
| 412 | template <typename MappingKindType> |
| 413 | static DiagnosedSilenceableFailure |
| 414 | verifyGpuMapping(std::optional<TransformOpInterface> transformOp, |
| 415 | scf::ForallOp forallOp) { |
| 416 | // Check the types of the mapping attributes match. |
| 417 | DiagnosedSilenceableFailure typeRes = |
| 418 | checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp); |
| 419 | if (!typeRes.succeeded()) |
| 420 | return typeRes; |
| 421 | |
| 422 | // Perform other non-types verifications. |
| 423 | if (!forallOp.isNormalized()) |
| 424 | return definiteFailureHelper(transformOp, target: forallOp, |
| 425 | message: "unsupported non-normalized loops" ); |
| 426 | if (forallOp.getNumResults() > 0) |
| 427 | return definiteFailureHelper(transformOp, target: forallOp, |
| 428 | message: "only bufferized scf.forall can be mapped" ); |
| 429 | bool useLinearMapping = forallOp.usesLinearMapping(); |
| 430 | // TODO: This would be more natural with support for Optional<EnumParameter> |
| 431 | // in GPUDeviceMappingAttr. |
| 432 | int64_t maxNumMappingsSupported = |
| 433 | useLinearMapping ? (getMaxEnumValForMappingId() - |
| 434 | static_cast<uint64_t>(MappingId::DimZ)) |
| 435 | : 3; |
| 436 | if (forallOp.getRank() > maxNumMappingsSupported) { |
| 437 | return definiteFailureHelper(transformOp, target: forallOp, |
| 438 | message: "scf.forall with rank > " ) |
| 439 | << maxNumMappingsSupported |
| 440 | << " does not lower for the specified mapping attribute type" ; |
| 441 | } |
| 442 | auto numParallelIterations = |
| 443 | getConstantIntValues(ofrs: forallOp.getMixedUpperBound()); |
| 444 | if (!forallOp.isNormalized() || !numParallelIterations.has_value()) { |
| 445 | return definiteFailureHelper( |
| 446 | transformOp, target: forallOp, |
| 447 | message: "requires statically sized, normalized forall op" ); |
| 448 | } |
| 449 | return DiagnosedSilenceableFailure::success(); |
| 450 | } |
| 451 | |
/// Struct to return the result of the rewrite of a forall operation.
struct ForallRewriteResult {
  // Sizes along each mapped dimension, sorted by mapping id.
  SmallVector<int64_t> mappingSizes;
  // The id values (e.g. block/thread ids) generated for each dimension.
  SmallVector<Value> mappingIds;
};
| 457 | |
| 458 | /// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR. |
// Walks `parent` and, for every `OpTy` id op whose dimension has a known
// mapping size of 1, replaces all uses of its result with `replacement`
// (typically a constant 0), since that id can only ever be 0.
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
                            OperationOrBlock *parent, Value replacement,
                            ArrayRef<int64_t> availableMappingSizes) {
  parent->walk([&](OpTy idOp) {
    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
  });
}
| 469 | |
| 470 | static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( |
| 471 | RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp, |
| 472 | scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes, |
| 473 | ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) { |
| 474 | LDBG("--start rewriteOneForallCommonImpl" ); |
| 475 | |
| 476 | // Step 1. Complete the mapping to a full mapping (with 1s) if necessary. |
| 477 | auto numParallelIterations = |
| 478 | getConstantIntValues(ofrs: forallOp.getMixedUpperBound()); |
| 479 | assert(forallOp.isNormalized() && numParallelIterations.has_value() && |
| 480 | "requires statically sized, normalized forall op" ); |
| 481 | SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value(); |
| 482 | SmallVector<DeviceMappingAttrInterface> forallMappingAttrsVec = |
| 483 | forallOp.getDeviceMappingAttrs(); |
| 484 | SetVector<Attribute> forallMappingAttrs; |
| 485 | forallMappingAttrs.insert_range(R&: forallMappingAttrsVec); |
| 486 | auto comparator = [](Attribute a, Attribute b) -> bool { |
| 487 | return cast<DeviceMappingAttrInterface>(Val&: a).getMappingId() < |
| 488 | cast<DeviceMappingAttrInterface>(Val&: b).getMappingId(); |
| 489 | }; |
| 490 | |
| 491 | // Step 1.b. In the linear case, compute the max mapping to avoid needlessly |
| 492 | // mapping all dimensions. In the 3-D mapping case we need to map all |
| 493 | // dimensions. |
| 494 | DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>( |
| 495 | Val: *llvm::max_element(Range&: forallMappingAttrs, C: comparator)); |
| 496 | DeviceMappingAttrInterface maxLinearMapping; |
| 497 | if (maxMapping.isLinearMapping()) |
| 498 | maxLinearMapping = maxMapping; |
| 499 | for (auto attr : gpuIdBuilder.mappingAttributes) { |
| 500 | // If attr overflows, just skip. |
| 501 | if (maxLinearMapping && comparator(maxLinearMapping, attr)) |
| 502 | continue; |
| 503 | // Try to insert. If element was already present, just continue. |
| 504 | if (!forallMappingAttrs.insert(X: attr)) |
| 505 | continue; |
| 506 | // Otherwise, we have a new insertion without a size -> use size 1. |
| 507 | tmpMappingSizes.push_back(Elt: 1); |
| 508 | } |
| 509 | LDBG("----tmpMappingSizes extracted from scf.forall op: " |
| 510 | << llvm::interleaved(tmpMappingSizes)); |
| 511 | |
| 512 | // Step 2. sort the values by the corresponding DeviceMappingAttrInterface. |
| 513 | SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey( |
| 514 | keys: forallMappingAttrs.getArrayRef(), values: tmpMappingSizes, compare: comparator); |
| 515 | LDBG("----forallMappingSizes: " << llvm::interleaved(forallMappingSizes)); |
| 516 | LDBG("----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs)); |
| 517 | |
| 518 | // Step 3. Generate the mappingIdOps using the provided generator. |
| 519 | Location loc = forallOp.getLoc(); |
| 520 | OpBuilder::InsertionGuard guard(rewriter); |
| 521 | rewriter.setInsertionPoint(forallOp); |
| 522 | SmallVector<int64_t> originalBasis(availableMappingSizes); |
| 523 | bool originalBasisWasProvided = !originalBasis.empty(); |
| 524 | if (!originalBasisWasProvided) { |
| 525 | LDBG("----originalBasis was not provided, deriving it and there will be no " |
| 526 | "predication" ); |
| 527 | originalBasis = forallMappingSizes; |
| 528 | while (originalBasis.size() < 3) |
| 529 | originalBasis.push_back(Elt: 1); |
| 530 | } else { |
| 531 | LDBG("----originalBasis was provided, using it, there will be predication" ); |
| 532 | } |
| 533 | LLVM_DEBUG( |
| 534 | llvm::interleaveComma(originalBasis, DBGS() << "------originalBasis: " ); |
| 535 | llvm::dbgs() << "\n" ); |
| 536 | |
| 537 | IdBuilderResult builderResult = |
| 538 | gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis); |
| 539 | if (!builderResult.errorMsg.empty()) |
| 540 | return definiteFailureHelper(transformOp, target: forallOp, message: builderResult.errorMsg); |
| 541 | |
| 542 | LLVM_DEBUG(DBGS() << builderResult); |
| 543 | |
| 544 | // Step 4. Map the induction variables to the mappingIdOps, this may involve |
| 545 | // a permutation. |
| 546 | SmallVector<Value> mappingIdOps = builderResult.mappingIdOps; |
| 547 | IRMapping bvm; |
| 548 | for (auto [iv, dim] : llvm::zip_equal( |
| 549 | t: forallOp.getInductionVars(), |
| 550 | u: forallMappingAttrs.getArrayRef().take_front(N: forallOp.getRank()))) { |
| 551 | auto mappingAttr = cast<DeviceMappingAttrInterface>(Val: dim); |
| 552 | Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()]; |
| 553 | LDBG("----map: " << iv << " to " << peIdOp); |
| 554 | bvm.map(from: iv, to: peIdOp); |
| 555 | } |
| 556 | |
| 557 | // Step 5. If the originalBasis is already known, create conditionals to |
| 558 | // predicate the region. Otherwise, the current forall determines the |
| 559 | // originalBasis and no predication occurs. |
| 560 | Value predicate; |
| 561 | if (originalBasisWasProvided) { |
| 562 | for (Value tmpPredicate : builderResult.predicateOps) { |
| 563 | predicate = predicate ? rewriter.create<arith::AndIOp>(location: loc, args&: predicate, |
| 564 | args&: tmpPredicate) |
| 565 | : tmpPredicate; |
| 566 | } |
| 567 | } |
| 568 | |
| 569 | // Step 6. Move the body of forallOp. |
| 570 | // Erase the terminator first, it will not be used. |
| 571 | rewriter.eraseOp(op: forallOp.getTerminator()); |
| 572 | Block *targetBlock; |
| 573 | Block::iterator insertionPoint; |
| 574 | if (predicate) { |
| 575 | // Step 6.a. If predicated, move at the beginning. |
| 576 | auto ifOp = rewriter.create<scf::IfOp>(location: loc, args&: predicate, |
| 577 | /*withElseRegion=*/args: false); |
| 578 | targetBlock = ifOp.thenBlock(); |
| 579 | insertionPoint = ifOp.thenBlock()->begin(); |
| 580 | } else { |
| 581 | // Step 6.b. Otherwise, move inline just at the rewriter insertion |
| 582 | // point. |
| 583 | targetBlock = forallOp->getBlock(); |
| 584 | insertionPoint = rewriter.getInsertionPoint(); |
| 585 | } |
| 586 | Block &sourceBlock = forallOp.getRegion().front(); |
| 587 | targetBlock->getOperations().splice(where: insertionPoint, |
| 588 | L2&: sourceBlock.getOperations()); |
| 589 | |
| 590 | // Step 7. RAUW indices. |
| 591 | for (Value loopIndex : forallOp.getInductionVars()) { |
| 592 | Value threadIdx = bvm.lookup(from: loopIndex); |
| 593 | rewriter.replaceAllUsesWith(from: loopIndex, to: threadIdx); |
| 594 | } |
| 595 | |
| 596 | // Step 8. Erase old op. |
| 597 | rewriter.eraseOp(op: forallOp); |
| 598 | |
| 599 | LDBG("----result forallMappingSizes: " |
| 600 | << llvm::interleaved(forallMappingSizes)); |
| 601 | LDBG("----result mappingIdOps: " << llvm::interleaved(mappingIdOps)); |
| 602 | |
| 603 | result = ForallRewriteResult{.mappingSizes: forallMappingSizes, .mappingIds: mappingIdOps}; |
| 604 | return DiagnosedSilenceableFailure::success(); |
| 605 | } |
| 606 | |
| 607 | //===----------------------------------------------------------------------===// |
| 608 | // MapForallToBlocks |
| 609 | //===----------------------------------------------------------------------===// |
| 610 | |
| 611 | DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl( |
| 612 | RewriterBase &rewriter, TransformOpInterface transformOp, |
| 613 | scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims, |
| 614 | const GpuIdBuilder &gpuIdBuilder) { |
| 615 | LDBG("Start mapForallToBlocksImpl" ); |
| 616 | |
| 617 | { |
| 618 | // GPU-specific verifications. There is no better place to anchor |
| 619 | // those right now: the ForallOp is target-independent and the transform |
| 620 | // op does not apply to individual ForallOp. |
| 621 | DiagnosedSilenceableFailure diag = |
| 622 | verifyGpuMapping<BlockMappingKind>(transformOp, forallOp); |
| 623 | if (!diag.succeeded()) |
| 624 | return diag; |
| 625 | } |
| 626 | |
| 627 | Location loc = forallOp.getLoc(); |
| 628 | Block *parentBlock = forallOp->getBlock(); |
| 629 | Value zero; |
| 630 | { |
| 631 | // Create an early zero index value for replacements and immediately reset |
| 632 | // the insertion point. |
| 633 | OpBuilder::InsertionGuard guard(rewriter); |
| 634 | rewriter.setInsertionPointToStart(parentBlock); |
| 635 | zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
| 636 | } |
| 637 | |
| 638 | ForallRewriteResult rewriteResult; |
| 639 | DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl( |
| 640 | rewriter, transformOp, forallOp, |
| 641 | /*availableMappingSizes=*/gridDims, result&: rewriteResult, gpuIdBuilder); |
| 642 | |
| 643 | // Return if anything goes wrong, use silenceable failure as a match |
| 644 | // failure. |
| 645 | if (!diag.succeeded()) |
| 646 | return diag; |
| 647 | |
| 648 | // If gridDims was not provided already, set it from the return. |
| 649 | if (gridDims.empty()) { |
| 650 | gridDims = rewriteResult.mappingSizes; |
| 651 | while (gridDims.size() < 3) |
| 652 | gridDims.push_back(Elt: 1); |
| 653 | } |
| 654 | assert(gridDims.size() == 3 && "Need 3-D gridDims" ); |
| 655 | |
| 656 | // Replace ids of dimensions known to be 1 by 0 to simplify the IR. |
| 657 | // Here, the result of mapping determines the available mapping sizes. |
| 658 | replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parent: parentBlock, replacement: zero, |
| 659 | availableMappingSizes: rewriteResult.mappingSizes); |
| 660 | |
| 661 | return DiagnosedSilenceableFailure::success(); |
| 662 | } |
| 663 | |
| 664 | DiagnosedSilenceableFailure |
| 665 | mlir::transform::gpu::findTopLevelForallOp(Operation *target, |
| 666 | scf::ForallOp &topLevelForallOp, |
| 667 | TransformOpInterface transformOp) { |
| 668 | auto walkResult = target->walk(callback: [&](scf::ForallOp forallOp) { |
| 669 | if (forallOp->getParentOfType<scf::ForallOp>()) |
| 670 | return WalkResult::advance(); |
| 671 | if (topLevelForallOp) |
| 672 | // TODO: Handle multiple forall if they are independent. |
| 673 | return WalkResult::interrupt(); |
| 674 | topLevelForallOp = forallOp; |
| 675 | return WalkResult::advance(); |
| 676 | }); |
| 677 | |
| 678 | if (walkResult.wasInterrupted() || !topLevelForallOp) |
| 679 | return transformOp.emitSilenceableError() |
| 680 | << "could not find a unique topLevel scf.forall" ; |
| 681 | return DiagnosedSilenceableFailure::success(); |
| 682 | } |
| 683 | |
/// Maps the top-level scf.forall contained in `target` to GPU blocks.
/// `target` must already be a gpu.launch op unless the `generate_gpu_launch`
/// attribute is set, in which case a fresh gpu.launch is created and the
/// forall is moved into it. On success, the launch op (with its grid
/// configuration updated from `gridDims`) is appended to `results`.
DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(Val: target);
  auto transformOp = cast<TransformOpInterface>(Val: getOperation());

  // Without `generate_gpu_launch`, the payload op must already be gpu.launch.
  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(loc: target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  // Locate the unique top-level forall to distribute; fail (silenceably) if
  // there is none or more than one.
  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag = mlir::transform::gpu::findTopLevelForallOp(
      target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(loc: target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  assert(topLevelForallOp && "expect an scf.forall");

  SmallVector<int64_t> gridDims{getGridDims()};
  // When mapping into an existing launch, the grid dims must be fully
  // specified (3-D); when generating the launch, they may be inferred later.
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure(message: "transform require size-3 mapping");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(topLevelForallOp);

  // Generate gpu launch here and move the forall inside
  if (getGenerateGpuLaunch()) {
    DiagnosedSilenceableFailure diag =
        createGpuLaunch(rewriter, loc: target->getLoc(), transformOp, launchOp&: gpuLaunch);
    if (!diag.succeeded())
      return diag;

    // Clone the forall into the fresh launch body and erase the original so
    // subsequent rewrites operate on the copy inside the launch region.
    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForallOp = rewriter.clone(op&: *topLevelForallOp);
    rewriter.eraseOp(op: topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(Val: newForallOp);
  }

  // The BlockIdBuilder adapts to whatever is thrown at it.
  bool useLinearMapping = false;
  if (topLevelForallOp.getMapping())
    useLinearMapping = topLevelForallOp.usesLinearMapping();

  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
      topLevelForallOp.getDeviceMaskingAttr();
  assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
  assert((!*maybeMaskingAttr || useLinearMapping) &&
         "masking requires linear mapping");

  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping,
                                      *maybeMaskingAttr);

  // Perform the actual forall-to-blocks rewrite; this may also fill in
  // `gridDims` when it was left empty.
  diag = mlir::transform::gpu::mapForallToBlocksImpl(
      rewriter, transformOp, forallOp: topLevelForallOp, gridDims, gpuIdBuilder: gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Set the GPU launch configuration for the grid dims late, this is
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        transformOp: cast<TransformOpInterface>(Val: getOperation()), gridDimX: gridDims[0],
                        gridDimY: gridDims[1], gridDimZ: gridDims[2]);

  results.push_back(op: gpuLaunch);
  return diag;
}
| 756 | |
| 757 | LogicalResult transform::MapForallToBlocks::verify() { |
| 758 | if (!getGridDims().empty() && getGridDims().size() != 3) { |
| 759 | return emitOpError() << "transform requires empty or size-3 grid_dims" ; |
| 760 | } |
| 761 | return success(); |
| 762 | } |
| 763 | |
| 764 | //===----------------------------------------------------------------------===// |
| 765 | // MapNestedForallToThreads |
| 766 | //===----------------------------------------------------------------------===// |
| 767 | |
| 768 | static DiagnosedSilenceableFailure checkMappingSpec( |
| 769 | std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp, |
| 770 | ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes, |
| 771 | int factor, bool useLinearMapping = false) { |
| 772 | if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) { |
| 773 | auto diag = definiteFailureHelper( |
| 774 | transformOp, target: forallOp, |
| 775 | message: Twine("3-D mapping: size of threadIdx.x must be a multiple of " ) + |
| 776 | Twine(factor)); |
| 777 | return diag; |
| 778 | } |
| 779 | if (computeProduct(basis: numParallelIterations) * factor > |
| 780 | computeProduct(basis: blockOrGridSizes)) { |
| 781 | auto diag = definiteFailureHelper( |
| 782 | transformOp, target: forallOp, |
| 783 | message: Twine("the number of required parallel resources (blocks or " |
| 784 | "threads) " ) + |
| 785 | Twine(computeProduct(basis: numParallelIterations) * factor) + |
| 786 | " overflows the number of available resources " + |
| 787 | Twine(computeProduct(basis: blockOrGridSizes))); |
| 788 | return diag; |
| 789 | } |
| 790 | return DiagnosedSilenceableFailure::success(); |
| 791 | } |
| 792 | |
/// Selects and constructs into `gpuIdBuilder` the GpuIdBuilder matching the
/// device mapping attribute of `forallOp` (warpgroup, warp, thread, or lane).
/// First performs sanity checks: the forall must be normalized with static
/// upper bounds, and the requested mapping must fit within `blockSizes`
/// (scaled by a warp/warpgroup `factor` for non-linear mappings).
static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                   scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                   int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
  DeviceMappingAttrInterface mappingAttr =
      forallOp.getDeviceMappingAttrs().front();
  bool useLinearMapping = mappingAttr.isLinearMapping();

  // Sanity checks that may result in runtime verification errors.
  auto numParallelIterations =
      getConstantIntValues(ofrs: (forallOp.getMixedUpperBound()));
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, target: forallOp,
        message: "requires statically sized, normalized forall op");
  }
  // Warp and warpgroup mappings consume multiple hardware threads per mapped
  // id; account for that scale factor in the resource check below.
  int64_t factor = 1;
  if (isa<GPUWarpgroupMappingAttr>(Val: mappingAttr)) {
    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
  } else if (isa<GPUWarpMappingAttr>(Val: mappingAttr)) {
    factor = warpSize;
  }
  DiagnosedSilenceableFailure diag =
      checkMappingSpec(transformOp, forallOp, numParallelIterations: numParallelIterations.value(),
                       blockOrGridSizes: blockSizes, factor, useLinearMapping);
  if (!diag.succeeded())
    return diag;

  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
      forallOp.getDeviceMaskingAttr();
  assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
  assert((!*maybeMaskingAttr || useLinearMapping) &&
         "masking requires linear mapping");

  // Start mapping.
  // NOTE(review): assigning the concrete builder into the base GpuIdBuilder
  // reference relies on GpuIdBuilder's copy semantics — presumably it carries
  // its state in base-class members so slicing is benign; confirm in Utils.h.
  MLIRContext *ctx = forallOp.getContext();
  gpuIdBuilder =
      TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
          .Case(caseFn: [&](GPUWarpgroupMappingAttr) {
            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping,
                                         *maybeMaskingAttr);
          })
          .Case(caseFn: [&](GPUWarpMappingAttr) {
            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping,
                                    *maybeMaskingAttr);
          })
          .Case(caseFn: [&](GPUThreadMappingAttr) {
            return GpuThreadIdBuilder(ctx, useLinearMapping, *maybeMaskingAttr);
          })
          .Case(caseFn: [&](GPULaneMappingAttr) {
            return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping,
                                    *maybeMaskingAttr);
          })
          .Default(defaultFn: [&](DeviceMappingAttrInterface) -> GpuIdBuilder {
            llvm_unreachable("unknown mapping attribute");
          });
  return DiagnosedSilenceableFailure::success();
}
| 851 | |
| 852 | DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl( |
| 853 | RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp, |
| 854 | scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize, |
| 855 | bool syncAfterDistribute) { |
| 856 | |
| 857 | { |
| 858 | // GPU-specific verifications. There is no better place to anchor |
| 859 | // those right now: the ForallOp is target-independent and the transform |
| 860 | // op does not apply to individual ForallOp. |
| 861 | DiagnosedSilenceableFailure diag = |
| 862 | verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp); |
| 863 | if (!diag.succeeded()) |
| 864 | return diag; |
| 865 | } |
| 866 | |
| 867 | GpuIdBuilder gpuIdBuilder; |
| 868 | { |
| 869 | // Try to construct the id builder, if it fails, return. |
| 870 | DiagnosedSilenceableFailure diag = getThreadIdBuilder( |
| 871 | transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder); |
| 872 | if (!diag.succeeded()) |
| 873 | return diag; |
| 874 | } |
| 875 | |
| 876 | Location loc = forallOp.getLoc(); |
| 877 | OpBuilder::InsertionGuard g(rewriter); |
| 878 | // Insert after to allow for syncthreads after `forall` is erased. |
| 879 | rewriter.setInsertionPointAfter(forallOp); |
| 880 | ForallRewriteResult rewriteResult; |
| 881 | DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl( |
| 882 | rewriter, transformOp, forallOp, availableMappingSizes: blockSizes, result&: rewriteResult, gpuIdBuilder); |
| 883 | if (!diag.succeeded()) |
| 884 | return diag; |
| 885 | // Add a syncthreads if needed. TODO: warpsync |
| 886 | if (syncAfterDistribute) |
| 887 | rewriter.create<BarrierOp>(location: loc); |
| 888 | |
| 889 | return DiagnosedSilenceableFailure::success(); |
| 890 | } |
| 891 | |
/// Maps all scf.forall ops nested under `target` to GPU threads, requiring a
/// size-3 `blockDims`. After mapping, ids of dimensions statically known to
/// be 1 are replaced by the constant 0 to simplify the IR.
DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
    bool syncAfterDistribute) {
  LDBG("Start mapNestedForallToThreadsImpl");
  if (blockDims.size() != 3) {
    return definiteFailureHelper(transformOp, target,
                                 message: "requires size-3 thread mapping");
  }

  // Create an early zero index value for replacements.
  Location loc = target->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  WalkResult walkResult = target->walk(callback: [&](scf::ForallOp forallOp) {
    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
        rewriter, transformOp, forallOp, blockSizes: blockDims, warpSize,
        syncAfterDistribute);
    // Definite failures abort the whole transform immediately.
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();
    // On success, do not recurse into the just-rewritten region.
    if (diag.succeeded())
      return WalkResult::skip();
    // Silenceable failures are treated as non-matches: keep walking. Note
    // that if the walk completes, such a diag is discarded and this function
    // still returns success.
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return diag;

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, parent: target, replacement: zero,
                                          availableMappingSizes: blockDims);

  return DiagnosedSilenceableFailure::success();
}
| 926 | |
/// Maps all scf.forall ops nested in a gpu.launch payload op to GPU threads.
/// Verifies the requested block dimensions against GPU limits, sets the
/// launch's block configuration, then distributes the nested foralls. On
/// success, the launch op is appended to `results`.
DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(Val: target);
  auto transformOp = cast<TransformOpInterface>(Val: getOperation());

  // Basic high-level verifications.
  if (!gpuLaunch)
    return emitSilenceableError() << "Given target is not a gpu.launch";

  // Mapping to block ids.
  SmallVector<int64_t> blockDims{getBlockDims()};
  // Only the block dims are checked here; grid dims are left unconstrained.
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, gridDimX: std::nullopt, gridDimY: std::nullopt, gridDimZ: std::nullopt,
                     blockDimX: blockDims[0], blockDimY: blockDims[1], blockDimZ: blockDims[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(loc: getLoc()) << getBlockDimsAttrName() << " is too large";
    return diag;
  }

  // Set the GPU launch configuration for the block dims early, this is not
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, gridDimX: std::nullopt,
                        gridDimY: std::nullopt, gridDimZ: std::nullopt, blockDimX: blockDims[0], blockDimY: blockDims[1],
                        blockDimZ: blockDims[2]);

  rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
  diag =
      mapNestedForallToThreadsImpl(rewriter, transformOp, target: gpuLaunch, blockDims,
                                   warpSize: getWarpSize(), syncAfterDistribute: getSyncAfterDistribute());

  results.push_back(op: gpuLaunch.getOperation());
  return diag;
}
| 961 | |
| 962 | //===----------------------------------------------------------------------===// |
| 963 | // Transform op registration |
| 964 | //===----------------------------------------------------------------------===// |
| 965 | |
namespace {
/// Transform dialect extension that registers the GPU transform ops and
/// declares the dialects those ops may generate operations for, so they are
/// loaded before the transforms run.
/// NOTE(review): a previous version of this comment mentioned declaring PDL
/// as a dependent dialect; no PDL dialect is declared here anymore.
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)

  GPUTransformDialectExtension() {
    declareGeneratedDialect<GPUDialect>();
    declareGeneratedDialect<amdgpu::AMDGPUDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<scf::SCFDialect>();
    // Register all ops generated from the TableGen op list.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace
| 987 | |
| 988 | #define GET_OP_CLASSES |
| 989 | #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc" |
| 990 | |
/// Registers the GPU transform dialect extension (and thus its transform ops)
/// with the given dialect registry.
void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}
| 994 | |