| 1 | //===- Utils.cpp - Utils for GPU transform ops ----------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/GPU/TransformOps/Utils.h" |
| 10 | |
| 11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 14 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 15 | #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" |
| 16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 17 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
| 18 | #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" |
| 19 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 20 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
| 21 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
| 22 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 23 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 24 | #include "mlir/IR/AffineExpr.h" |
| 25 | #include "mlir/IR/Builders.h" |
| 26 | #include "mlir/IR/BuiltinAttributes.h" |
| 27 | #include "mlir/IR/IRMapping.h" |
| 28 | #include "mlir/IR/MLIRContext.h" |
| 29 | #include "mlir/IR/OpDefinition.h" |
| 30 | #include "mlir/IR/Value.h" |
| 31 | #include "mlir/IR/Visitors.h" |
| 32 | #include "mlir/Support/LLVM.h" |
| 33 | #include "llvm/ADT/STLExtras.h" |
| 34 | #include "llvm/ADT/SmallVector.h" |
| 35 | #include "llvm/ADT/TypeSwitch.h" |
| 36 | #include "llvm/Support/Debug.h" |
| 37 | #include "llvm/Support/InterleavedRange.h" |
| 38 | |
| 39 | using namespace mlir; |
| 40 | using namespace mlir::gpu; |
| 41 | using namespace mlir::transform; |
| 42 | using namespace mlir::transform::gpu; |
| 43 | |
// Debug category used by LLVM_DEBUG for everything in this file.
#define DEBUG_TYPE "gpu-transforms"

// Debug stream prefixed with "[<DEBUG_TYPE>] ".
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
// Print `X` followed by a newline, only when debugging is enabled.
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
// NOTE(review): DEBUG_TYPE_ALIAS is not defined in this file, so this macro
// cannot be expanded here as-is; define DEBUG_TYPE_ALIAS before using it.
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
| 49 | |
| 50 | /// Return a flattened thread id for the workgroup with given sizes. |
| 51 | template <typename ThreadOrBlockIdOp> |
| 52 | static Value buildLinearId(RewriterBase &rewriter, Location loc, |
| 53 | ArrayRef<OpFoldResult> originalBasisOfr) { |
| 54 | LLVM_DEBUG(DBGS() << "----buildLinearId with originalBasisOfr: " |
| 55 | << llvm::interleaved(originalBasisOfr) << "\n" ); |
| 56 | assert(originalBasisOfr.size() == 3 && "expected 3 sizes" ); |
| 57 | IndexType indexType = rewriter.getIndexType(); |
| 58 | AffineExpr tx, ty, tz, bdx, bdy; |
| 59 | bindDims(ctx: rewriter.getContext(), exprs&: tx, exprs&: ty, exprs&: tz); |
| 60 | bindSymbols(ctx: rewriter.getContext(), exprs&: bdx, exprs&: bdy); |
| 61 | SmallVector<OpFoldResult> vals{ |
| 62 | rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::x) |
| 63 | .getResult(), |
| 64 | rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::y) |
| 65 | .getResult(), |
| 66 | rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::z) |
| 67 | .getResult(), |
| 68 | originalBasisOfr[0], originalBasisOfr[1]}; |
| 69 | OpFoldResult ofr = affine::makeComposedFoldedAffineApply( |
| 70 | b&: rewriter, loc, expr: tx + ty * bdx + tz * bdx * bdy, operands: vals); |
| 71 | return getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr); |
| 72 | } |
| 73 | |
| 74 | /// Create a linear id builder that takes the `originalBasisOfr` and decompose |
| 75 | /// it in the basis of `forallMappingSizes`. The linear id builder returns an |
| 76 | /// n-D vector of ids for indexing and 1-D size + id for predicate generation. |
| 77 | template <typename ThreadOrBlockIdOp> |
| 78 | static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) { |
| 79 | auto res = [multiplicity](RewriterBase &rewriter, Location loc, |
| 80 | ArrayRef<int64_t> forallMappingSizes, |
| 81 | ArrayRef<int64_t> originalBasis) { |
| 82 | SmallVector<OpFoldResult> originalBasisOfr = |
| 83 | getAsIndexOpFoldResult(ctx: rewriter.getContext(), values: originalBasis); |
| 84 | OpFoldResult linearId = |
| 85 | buildLinearId<ThreadOrBlockIdOp>(rewriter, loc, originalBasisOfr); |
| 86 | // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in |
| 87 | // "row-major" order. |
| 88 | SmallVector<int64_t> reverseBasisSizes(llvm::reverse(C&: forallMappingSizes)); |
| 89 | SmallVector<int64_t> strides = computeStrides(sizes: reverseBasisSizes); |
| 90 | AffineExpr d0 = getAffineDimExpr(position: 0, context: rewriter.getContext()); |
| 91 | OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply( |
| 92 | b&: rewriter, loc, expr: d0.floorDiv(v: multiplicity), operands: {linearId}); |
| 93 | SmallVector<AffineExpr> delinearizingExprs = delinearize(linearIndex: d0, strides); |
| 94 | SmallVector<Value> ids; |
| 95 | // Reverse back to be in [0 .. n] order. |
| 96 | for (AffineExpr e : llvm::reverse(C&: delinearizingExprs)) { |
| 97 | ids.push_back( |
| 98 | Elt: affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId})); |
| 99 | } |
| 100 | |
| 101 | LLVM_DEBUG(DBGS() << "--delinearization basis: " |
| 102 | << llvm::interleaved(reverseBasisSizes) << "\n" ; |
| 103 | DBGS() << "--delinearization strides: " |
| 104 | << llvm::interleaved(strides) << "\n" ; |
| 105 | DBGS() << "--delinearization exprs: " |
| 106 | << llvm::interleaved(delinearizingExprs) << "\n" ; |
| 107 | DBGS() << "--ids: " << llvm::interleaved(ids) << "\n" ); |
| 108 | |
| 109 | // Return n-D ids for indexing and 1-D size + id for predicate generation. |
| 110 | return IdBuilderResult{ |
| 111 | /*mappingIdOps=*/ids, |
| 112 | /*availableMappingSizes=*/ |
| 113 | SmallVector<int64_t>{computeProduct(basis: originalBasis)}, |
| 114 | // `forallMappingSizes` iterate in the scaled basis, they need to be |
| 115 | // scaled back into the original basis to provide tight |
| 116 | // activeMappingSizes quantities for predication. |
| 117 | /*activeMappingSizes=*/ |
| 118 | SmallVector<int64_t>{computeProduct(basis: forallMappingSizes) * multiplicity}, |
| 119 | /*activeIdOps=*/SmallVector<Value>{cast<Value>(Val&: linearId)}}; |
| 120 | }; |
| 121 | |
| 122 | return res; |
| 123 | } |
| 124 | |
| 125 | /// Create a simple 3-D id builder that takes the `originalBasisOfr` |
| 126 | /// The 3-D id builder returns a 3-D vector of ids for indexing and 3-D sizes |
| 127 | /// + ids for predicate generation. |
| 128 | template <typename ThreadOrBlockIdOp> |
| 129 | static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) { |
| 130 | auto res = [multiplicity](RewriterBase &rewriter, Location loc, |
| 131 | ArrayRef<int64_t> forallMappingSizes, |
| 132 | ArrayRef<int64_t> originalBasis) { |
| 133 | IndexType indexType = rewriter.getIndexType(); |
| 134 | SmallVector<Value> ids{ |
| 135 | rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::x), |
| 136 | rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::y), |
| 137 | rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::z)}; |
| 138 | // In the 3-D mapping case, scale the first dimension by the multiplicity. |
| 139 | SmallVector<Value> scaledIds = ids; |
| 140 | AffineExpr d0 = getAffineDimExpr(position: 0, context: rewriter.getContext()); |
| 141 | scaledIds[0] = cast<Value>(Val: affine::makeComposedFoldedAffineApply( |
| 142 | b&: rewriter, loc, expr: d0.floorDiv(v: multiplicity), operands: {scaledIds[0]})); |
| 143 | // In the 3-D mapping case, unscale the first dimension by the multiplicity. |
| 144 | SmallVector<int64_t> forallMappingSizeInOriginalBasis(forallMappingSizes); |
| 145 | forallMappingSizeInOriginalBasis[0] *= multiplicity; |
| 146 | return IdBuilderResult{ |
| 147 | /*mappingIdOps=*/scaledIds, |
| 148 | /*availableMappingSizes=*/SmallVector<int64_t>{originalBasis}, |
| 149 | // `forallMappingSizes` iterate in the scaled basis, they need to be |
| 150 | // scaled back into the original basis to provide tight |
| 151 | // activeMappingSizes quantities for predication. |
| 152 | /*activeMappingSizes=*/ |
| 153 | SmallVector<int64_t>{forallMappingSizeInOriginalBasis}, |
| 154 | /*activeIdOps=*/ids}; |
| 155 | }; |
| 156 | return res; |
| 157 | } |
| 158 | |
| 159 | namespace mlir { |
| 160 | namespace transform { |
| 161 | namespace gpu { |
| 162 | |
| 163 | GpuIdBuilder::GpuIdBuilder(MLIRContext *ctx, bool useLinearMapping, |
| 164 | const MappingIdBuilderFnType &fn) |
| 165 | : mappingAttributes(), idBuilder() { |
| 166 | if (useLinearMapping) { |
| 167 | for (uint64_t d = static_cast<uint64_t>(MappingId::LinearDim0), |
| 168 | e = getMaxEnumValForMappingId(); |
| 169 | d <= e; ++d) |
| 170 | mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value())); |
| 171 | } else { |
| 172 | for (uint64_t d = static_cast<uint64_t>(MappingId::DimX), |
| 173 | e = static_cast<uint64_t>(MappingId::DimZ); |
| 174 | d <= e; ++d) |
| 175 | mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value())); |
| 176 | } |
| 177 | } |
| 178 | |
| 179 | GpuBlockIdBuilder::GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping) |
| 180 | : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) { |
| 181 | return GPUBlockMappingAttr::get(ctx, id); |
| 182 | }) { |
| 183 | idBuilder = useLinearMapping |
| 184 | ? commonLinearIdBuilderFn<BlockIdOp>(/*multiplicity=*/1) |
| 185 | : common3DIdBuilderFn<BlockIdOp>(/*multiplicity=*/1); |
| 186 | } |
| 187 | |
| 188 | GpuWarpgroupIdBuilder::GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize, |
| 189 | bool useLinearMapping) |
| 190 | : GpuIdBuilder(ctx, useLinearMapping, |
| 191 | [](MLIRContext *ctx, MappingId id) { |
| 192 | return GPUWarpgroupMappingAttr::get(ctx, id); |
| 193 | }), |
| 194 | warpSize(warpSize) { |
| 195 | idBuilder = useLinearMapping |
| 196 | ? commonLinearIdBuilderFn<ThreadIdOp>( |
| 197 | /*multiplicity=*/kNumWarpsPerGroup * warpSize) |
| 198 | : common3DIdBuilderFn<ThreadIdOp>( |
| 199 | /*multiplicity=*/kNumWarpsPerGroup * warpSize); |
| 200 | } |
| 201 | |
| 202 | GpuWarpIdBuilder::GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize, |
| 203 | bool useLinearMapping) |
| 204 | : GpuIdBuilder(ctx, useLinearMapping, |
| 205 | [](MLIRContext *ctx, MappingId id) { |
| 206 | return GPUWarpMappingAttr::get(ctx, id); |
| 207 | }), |
| 208 | warpSize(warpSize) { |
| 209 | idBuilder = |
| 210 | useLinearMapping |
| 211 | ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize) |
| 212 | : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize); |
| 213 | } |
| 214 | |
| 215 | GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping) |
| 216 | : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) { |
| 217 | return GPUThreadMappingAttr::get(ctx, id); |
| 218 | }) { |
| 219 | idBuilder = useLinearMapping |
| 220 | ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1) |
| 221 | : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1); |
| 222 | } |
| 223 | |
| 224 | DiagnosedSilenceableFailure checkGpuLimits(TransformOpInterface transformOp, |
| 225 | std::optional<int64_t> gridDimX, |
| 226 | std::optional<int64_t> gridDimY, |
| 227 | std::optional<int64_t> gridDimZ, |
| 228 | std::optional<int64_t> blockDimX, |
| 229 | std::optional<int64_t> blockDimY, |
| 230 | std::optional<int64_t> blockDimZ) { |
| 231 | |
| 232 | // TODO: pass a configuration object to set the limits properly. |
| 233 | |
| 234 | if ((blockDimX.value_or(u: 1) * blockDimY.value_or(u: 1) * blockDimZ.value_or(u: 1)) > |
| 235 | kMaxTotalBlockdim || |
| 236 | (gridDimX.value_or(u: 1) * gridDimY.value_or(u: 1) * gridDimZ.value_or(u: 1)) > |
| 237 | kMaxTotalGriddim || |
| 238 | blockDimX.value_or(u: 1) > kMaxBlockdimx || |
| 239 | blockDimY.value_or(u: 1) > kMaxBlockdimy || |
| 240 | blockDimZ.value_or(u: 1) > kMaxBlockdimz || |
| 241 | gridDimY.value_or(u: 1) > kMaxGriddimy || |
| 242 | gridDimZ.value_or(u: 1) > kMaxGriddimz || |
| 243 | gridDimX.value_or(u: 1) > kMaxGriddimx) { |
| 244 | return transformOp.emitSilenceableError() |
| 245 | << "Trying to launch a GPU kernel with grid_dims = (" |
| 246 | << gridDimX.value_or(u: 1) << ", " << gridDimY.value_or(u: 1) << ", " |
| 247 | << gridDimZ.value_or(u: 1) << ") block_dims = (" |
| 248 | << blockDimX.value_or(u: 1) << ", " << blockDimY.value_or(u: 1) << ", " |
| 249 | << blockDimZ.value_or(u: 1) << "). It is larger than the limits." ; |
| 250 | } |
| 251 | return DiagnosedSilenceableFailure::success(); |
| 252 | } |
| 253 | |
| 254 | DiagnosedSilenceableFailure createGpuLaunch( |
| 255 | RewriterBase &rewriter, Location loc, TransformOpInterface transformOp, |
| 256 | LaunchOp &launchOp, std::optional<int64_t> gridDimX, |
| 257 | std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ, |
| 258 | std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY, |
| 259 | std::optional<int64_t> blockDimZ) { |
| 260 | DiagnosedSilenceableFailure diag = |
| 261 | checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX, |
| 262 | blockDimY, blockDimZ); |
| 263 | if (!diag.succeeded()) |
| 264 | return diag; |
| 265 | |
| 266 | auto createConst = [&](int dim) { |
| 267 | return rewriter.create<arith::ConstantIndexOp>(location: loc, args&: dim); |
| 268 | }; |
| 269 | OpBuilder::InsertionGuard guard(rewriter); |
| 270 | Value one = createConst(1); |
| 271 | Value gridSizeX = gridDimX.has_value() ? createConst(gridDimX.value()) : one; |
| 272 | Value gridSizeY = gridDimY.has_value() ? createConst(gridDimY.value()) : one; |
| 273 | Value gridSizeZ = gridDimZ.has_value() ? createConst(gridDimZ.value()) : one; |
| 274 | Value blkSizeX = blockDimX.has_value() ? createConst(blockDimX.value()) : one; |
| 275 | Value blkSizeY = blockDimY.has_value() ? createConst(blockDimY.value()) : one; |
| 276 | Value blkSizeZ = blockDimZ.has_value() ? createConst(blockDimZ.value()) : one; |
| 277 | launchOp = rewriter.create<LaunchOp>(loc, gridSizeX, gridSizeY, gridSizeZ, |
| 278 | blkSizeX, blkSizeY, blkSizeZ); |
| 279 | rewriter.setInsertionPointToEnd(&launchOp.getBody().front()); |
| 280 | rewriter.create<TerminatorOp>(loc); |
| 281 | return DiagnosedSilenceableFailure::success(); |
| 282 | } |
| 283 | |
| 284 | /// Alter kernel configuration of the given kernel. |
| 285 | DiagnosedSilenceableFailure alterGpuLaunch( |
| 286 | RewriterBase &rewriter, LaunchOp gpuLaunch, |
| 287 | TransformOpInterface transformOp, std::optional<int64_t> gridDimX, |
| 288 | std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ, |
| 289 | std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY, |
| 290 | std::optional<int64_t> blockDimZ) { |
| 291 | DiagnosedSilenceableFailure diag = |
| 292 | checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX, |
| 293 | blockDimY, blockDimZ); |
| 294 | if (!diag.succeeded()) |
| 295 | return diag; |
| 296 | |
| 297 | KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues(); |
| 298 | OpBuilder::InsertionGuard guard(rewriter); |
| 299 | rewriter.setInsertionPointAfterValue(currentBlockdim.x); |
| 300 | auto createConstValue = [&](int dim) { |
| 301 | return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(), |
| 302 | dim); |
| 303 | }; |
| 304 | |
| 305 | if (gridDimX.has_value()) |
| 306 | gpuLaunch.getGridSizeXMutable().assign(createConstValue(gridDimX.value())); |
| 307 | if (gridDimY.has_value()) |
| 308 | gpuLaunch.getGridSizeYMutable().assign(createConstValue(gridDimY.value())); |
| 309 | if (gridDimZ.has_value()) |
| 310 | gpuLaunch.getGridSizeZMutable().assign(createConstValue(gridDimZ.value())); |
| 311 | if (blockDimX.has_value()) |
| 312 | gpuLaunch.getBlockSizeXMutable().assign( |
| 313 | createConstValue(blockDimX.value())); |
| 314 | if (blockDimY.has_value()) |
| 315 | gpuLaunch.getBlockSizeYMutable().assign( |
| 316 | createConstValue(blockDimY.value())); |
| 317 | if (blockDimZ.has_value()) |
| 318 | gpuLaunch.getBlockSizeZMutable().assign( |
| 319 | createConstValue(blockDimZ.value())); |
| 320 | return DiagnosedSilenceableFailure::success(); |
| 321 | } |
| 322 | |
| 323 | } // namespace gpu |
| 324 | } // namespace transform |
| 325 | } // namespace mlir |
| 326 | |