1 | //===- Utils.cpp - Utils for GPU transform ops ----------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/GPU/TransformOps/Utils.h" |
10 | |
11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
14 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
15 | #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" |
16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
17 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
18 | #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" |
19 | #include "mlir/Dialect/SCF/IR/SCF.h" |
20 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
21 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
22 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
23 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
24 | #include "mlir/IR/AffineExpr.h" |
25 | #include "mlir/IR/Builders.h" |
26 | #include "mlir/IR/BuiltinAttributes.h" |
27 | #include "mlir/IR/IRMapping.h" |
28 | #include "mlir/IR/MLIRContext.h" |
29 | #include "mlir/IR/OpDefinition.h" |
30 | #include "mlir/IR/Value.h" |
31 | #include "mlir/IR/Visitors.h" |
32 | #include "mlir/Support/LLVM.h" |
33 | #include "llvm/ADT/STLExtras.h" |
34 | #include "llvm/ADT/SmallVector.h" |
35 | #include "llvm/ADT/TypeSwitch.h" |
36 | #include "llvm/Support/Debug.h" |
37 | #include "llvm/Support/InterleavedRange.h" |
38 | |
39 | using namespace mlir; |
40 | using namespace mlir::gpu; |
41 | using namespace mlir::transform; |
42 | using namespace mlir::transform::gpu; |
43 | |
#define DEBUG_TYPE "gpu-transforms"

// DBGS() starts a debug line prefixed with "[gpu-transforms] "; it writes
// unconditionally to llvm::dbgs(), so use it inside LLVM_DEBUG(...).
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
// LDBG(X) prints X followed by a newline, guarded by LLVM_DEBUG.
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
49 | |
50 | /// Return a flattened thread id for the workgroup with given sizes. |
51 | template <typename ThreadOrBlockIdOp> |
52 | static Value buildLinearId(RewriterBase &rewriter, Location loc, |
53 | ArrayRef<OpFoldResult> originalBasisOfr) { |
54 | LLVM_DEBUG(DBGS() << "----buildLinearId with originalBasisOfr: " |
55 | << llvm::interleaved(originalBasisOfr) << "\n" ); |
56 | assert(originalBasisOfr.size() == 3 && "expected 3 sizes" ); |
57 | IndexType indexType = rewriter.getIndexType(); |
58 | AffineExpr tx, ty, tz, bdx, bdy; |
59 | bindDims(ctx: rewriter.getContext(), exprs&: tx, exprs&: ty, exprs&: tz); |
60 | bindSymbols(ctx: rewriter.getContext(), exprs&: bdx, exprs&: bdy); |
61 | SmallVector<OpFoldResult> vals{ |
62 | rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::x) |
63 | .getResult(), |
64 | rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::y) |
65 | .getResult(), |
66 | rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::z) |
67 | .getResult(), |
68 | originalBasisOfr[0], originalBasisOfr[1]}; |
69 | OpFoldResult ofr = affine::makeComposedFoldedAffineApply( |
70 | b&: rewriter, loc, expr: tx + ty * bdx + tz * bdx * bdy, operands: vals); |
71 | return getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr); |
72 | } |
73 | |
74 | /// Create a linear id builder that takes the `originalBasisOfr` and decompose |
75 | /// it in the basis of `forallMappingSizes`. The linear id builder returns an |
76 | /// n-D vector of ids for indexing and 1-D size + id for predicate generation. |
77 | template <typename ThreadOrBlockIdOp> |
78 | static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) { |
79 | auto res = [multiplicity](RewriterBase &rewriter, Location loc, |
80 | ArrayRef<int64_t> forallMappingSizes, |
81 | ArrayRef<int64_t> originalBasis) { |
82 | SmallVector<OpFoldResult> originalBasisOfr = |
83 | getAsIndexOpFoldResult(ctx: rewriter.getContext(), values: originalBasis); |
84 | OpFoldResult linearId = |
85 | buildLinearId<ThreadOrBlockIdOp>(rewriter, loc, originalBasisOfr); |
86 | // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in |
87 | // "row-major" order. |
88 | SmallVector<int64_t> reverseBasisSizes(llvm::reverse(C&: forallMappingSizes)); |
89 | SmallVector<int64_t> strides = computeStrides(sizes: reverseBasisSizes); |
90 | AffineExpr d0 = getAffineDimExpr(position: 0, context: rewriter.getContext()); |
91 | OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply( |
92 | b&: rewriter, loc, expr: d0.floorDiv(v: multiplicity), operands: {linearId}); |
93 | SmallVector<AffineExpr> delinearizingExprs = delinearize(linearIndex: d0, strides); |
94 | SmallVector<Value> ids; |
95 | // Reverse back to be in [0 .. n] order. |
96 | for (AffineExpr e : llvm::reverse(C&: delinearizingExprs)) { |
97 | ids.push_back( |
98 | Elt: affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId})); |
99 | } |
100 | |
101 | LLVM_DEBUG(DBGS() << "--delinearization basis: " |
102 | << llvm::interleaved(reverseBasisSizes) << "\n" ; |
103 | DBGS() << "--delinearization strides: " |
104 | << llvm::interleaved(strides) << "\n" ; |
105 | DBGS() << "--delinearization exprs: " |
106 | << llvm::interleaved(delinearizingExprs) << "\n" ; |
107 | DBGS() << "--ids: " << llvm::interleaved(ids) << "\n" ); |
108 | |
109 | // Return n-D ids for indexing and 1-D size + id for predicate generation. |
110 | return IdBuilderResult{ |
111 | /*mappingIdOps=*/ids, |
112 | /*availableMappingSizes=*/ |
113 | SmallVector<int64_t>{computeProduct(basis: originalBasis)}, |
114 | // `forallMappingSizes` iterate in the scaled basis, they need to be |
115 | // scaled back into the original basis to provide tight |
116 | // activeMappingSizes quantities for predication. |
117 | /*activeMappingSizes=*/ |
118 | SmallVector<int64_t>{computeProduct(basis: forallMappingSizes) * multiplicity}, |
119 | /*activeIdOps=*/SmallVector<Value>{cast<Value>(Val&: linearId)}}; |
120 | }; |
121 | |
122 | return res; |
123 | } |
124 | |
125 | /// Create a simple 3-D id builder that takes the `originalBasisOfr` |
126 | /// The 3-D id builder returns a 3-D vector of ids for indexing and 3-D sizes |
127 | /// + ids for predicate generation. |
128 | template <typename ThreadOrBlockIdOp> |
129 | static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) { |
130 | auto res = [multiplicity](RewriterBase &rewriter, Location loc, |
131 | ArrayRef<int64_t> forallMappingSizes, |
132 | ArrayRef<int64_t> originalBasis) { |
133 | IndexType indexType = rewriter.getIndexType(); |
134 | SmallVector<Value> ids{ |
135 | rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::x), |
136 | rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::y), |
137 | rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::z)}; |
138 | // In the 3-D mapping case, scale the first dimension by the multiplicity. |
139 | SmallVector<Value> scaledIds = ids; |
140 | AffineExpr d0 = getAffineDimExpr(position: 0, context: rewriter.getContext()); |
141 | scaledIds[0] = cast<Value>(Val: affine::makeComposedFoldedAffineApply( |
142 | b&: rewriter, loc, expr: d0.floorDiv(v: multiplicity), operands: {scaledIds[0]})); |
143 | // In the 3-D mapping case, unscale the first dimension by the multiplicity. |
144 | SmallVector<int64_t> forallMappingSizeInOriginalBasis(forallMappingSizes); |
145 | forallMappingSizeInOriginalBasis[0] *= multiplicity; |
146 | return IdBuilderResult{ |
147 | /*mappingIdOps=*/scaledIds, |
148 | /*availableMappingSizes=*/SmallVector<int64_t>{originalBasis}, |
149 | // `forallMappingSizes` iterate in the scaled basis, they need to be |
150 | // scaled back into the original basis to provide tight |
151 | // activeMappingSizes quantities for predication. |
152 | /*activeMappingSizes=*/ |
153 | SmallVector<int64_t>{forallMappingSizeInOriginalBasis}, |
154 | /*activeIdOps=*/ids}; |
155 | }; |
156 | return res; |
157 | } |
158 | |
159 | namespace mlir { |
160 | namespace transform { |
161 | namespace gpu { |
162 | |
163 | GpuIdBuilder::GpuIdBuilder(MLIRContext *ctx, bool useLinearMapping, |
164 | const MappingIdBuilderFnType &fn) |
165 | : mappingAttributes(), idBuilder() { |
166 | if (useLinearMapping) { |
167 | for (uint64_t d = static_cast<uint64_t>(MappingId::LinearDim0), |
168 | e = getMaxEnumValForMappingId(); |
169 | d <= e; ++d) |
170 | mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value())); |
171 | } else { |
172 | for (uint64_t d = static_cast<uint64_t>(MappingId::DimX), |
173 | e = static_cast<uint64_t>(MappingId::DimZ); |
174 | d <= e; ++d) |
175 | mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value())); |
176 | } |
177 | } |
178 | |
179 | GpuBlockIdBuilder::GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping) |
180 | : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) { |
181 | return GPUBlockMappingAttr::get(ctx, id); |
182 | }) { |
183 | idBuilder = useLinearMapping |
184 | ? commonLinearIdBuilderFn<BlockIdOp>(/*multiplicity=*/1) |
185 | : common3DIdBuilderFn<BlockIdOp>(/*multiplicity=*/1); |
186 | } |
187 | |
188 | GpuWarpgroupIdBuilder::GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize, |
189 | bool useLinearMapping) |
190 | : GpuIdBuilder(ctx, useLinearMapping, |
191 | [](MLIRContext *ctx, MappingId id) { |
192 | return GPUWarpgroupMappingAttr::get(ctx, id); |
193 | }), |
194 | warpSize(warpSize) { |
195 | idBuilder = useLinearMapping |
196 | ? commonLinearIdBuilderFn<ThreadIdOp>( |
197 | /*multiplicity=*/kNumWarpsPerGroup * warpSize) |
198 | : common3DIdBuilderFn<ThreadIdOp>( |
199 | /*multiplicity=*/kNumWarpsPerGroup * warpSize); |
200 | } |
201 | |
202 | GpuWarpIdBuilder::GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize, |
203 | bool useLinearMapping) |
204 | : GpuIdBuilder(ctx, useLinearMapping, |
205 | [](MLIRContext *ctx, MappingId id) { |
206 | return GPUWarpMappingAttr::get(ctx, id); |
207 | }), |
208 | warpSize(warpSize) { |
209 | idBuilder = |
210 | useLinearMapping |
211 | ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize) |
212 | : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize); |
213 | } |
214 | |
215 | GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping) |
216 | : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) { |
217 | return GPUThreadMappingAttr::get(ctx, id); |
218 | }) { |
219 | idBuilder = useLinearMapping |
220 | ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1) |
221 | : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1); |
222 | } |
223 | |
224 | DiagnosedSilenceableFailure checkGpuLimits(TransformOpInterface transformOp, |
225 | std::optional<int64_t> gridDimX, |
226 | std::optional<int64_t> gridDimY, |
227 | std::optional<int64_t> gridDimZ, |
228 | std::optional<int64_t> blockDimX, |
229 | std::optional<int64_t> blockDimY, |
230 | std::optional<int64_t> blockDimZ) { |
231 | |
232 | // TODO: pass a configuration object to set the limits properly. |
233 | |
234 | if ((blockDimX.value_or(u: 1) * blockDimY.value_or(u: 1) * blockDimZ.value_or(u: 1)) > |
235 | kMaxTotalBlockdim || |
236 | (gridDimX.value_or(u: 1) * gridDimY.value_or(u: 1) * gridDimZ.value_or(u: 1)) > |
237 | kMaxTotalGriddim || |
238 | blockDimX.value_or(u: 1) > kMaxBlockdimx || |
239 | blockDimY.value_or(u: 1) > kMaxBlockdimy || |
240 | blockDimZ.value_or(u: 1) > kMaxBlockdimz || |
241 | gridDimY.value_or(u: 1) > kMaxGriddimy || |
242 | gridDimZ.value_or(u: 1) > kMaxGriddimz || |
243 | gridDimX.value_or(u: 1) > kMaxGriddimx) { |
244 | return transformOp.emitSilenceableError() |
245 | << "Trying to launch a GPU kernel with grid_dims = (" |
246 | << gridDimX.value_or(u: 1) << ", " << gridDimY.value_or(u: 1) << ", " |
247 | << gridDimZ.value_or(u: 1) << ") block_dims = (" |
248 | << blockDimX.value_or(u: 1) << ", " << blockDimY.value_or(u: 1) << ", " |
249 | << blockDimZ.value_or(u: 1) << "). It is larger than the limits." ; |
250 | } |
251 | return DiagnosedSilenceableFailure::success(); |
252 | } |
253 | |
254 | DiagnosedSilenceableFailure createGpuLaunch( |
255 | RewriterBase &rewriter, Location loc, TransformOpInterface transformOp, |
256 | LaunchOp &launchOp, std::optional<int64_t> gridDimX, |
257 | std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ, |
258 | std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY, |
259 | std::optional<int64_t> blockDimZ) { |
260 | DiagnosedSilenceableFailure diag = |
261 | checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX, |
262 | blockDimY, blockDimZ); |
263 | if (!diag.succeeded()) |
264 | return diag; |
265 | |
266 | auto createConst = [&](int dim) { |
267 | return rewriter.create<arith::ConstantIndexOp>(location: loc, args&: dim); |
268 | }; |
269 | OpBuilder::InsertionGuard guard(rewriter); |
270 | Value one = createConst(1); |
271 | Value gridSizeX = gridDimX.has_value() ? createConst(gridDimX.value()) : one; |
272 | Value gridSizeY = gridDimY.has_value() ? createConst(gridDimY.value()) : one; |
273 | Value gridSizeZ = gridDimZ.has_value() ? createConst(gridDimZ.value()) : one; |
274 | Value blkSizeX = blockDimX.has_value() ? createConst(blockDimX.value()) : one; |
275 | Value blkSizeY = blockDimY.has_value() ? createConst(blockDimY.value()) : one; |
276 | Value blkSizeZ = blockDimZ.has_value() ? createConst(blockDimZ.value()) : one; |
277 | launchOp = rewriter.create<LaunchOp>(loc, gridSizeX, gridSizeY, gridSizeZ, |
278 | blkSizeX, blkSizeY, blkSizeZ); |
279 | rewriter.setInsertionPointToEnd(&launchOp.getBody().front()); |
280 | rewriter.create<TerminatorOp>(loc); |
281 | return DiagnosedSilenceableFailure::success(); |
282 | } |
283 | |
284 | /// Alter kernel configuration of the given kernel. |
285 | DiagnosedSilenceableFailure alterGpuLaunch( |
286 | RewriterBase &rewriter, LaunchOp gpuLaunch, |
287 | TransformOpInterface transformOp, std::optional<int64_t> gridDimX, |
288 | std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ, |
289 | std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY, |
290 | std::optional<int64_t> blockDimZ) { |
291 | DiagnosedSilenceableFailure diag = |
292 | checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX, |
293 | blockDimY, blockDimZ); |
294 | if (!diag.succeeded()) |
295 | return diag; |
296 | |
297 | KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues(); |
298 | OpBuilder::InsertionGuard guard(rewriter); |
299 | rewriter.setInsertionPointAfterValue(currentBlockdim.x); |
300 | auto createConstValue = [&](int dim) { |
301 | return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(), |
302 | dim); |
303 | }; |
304 | |
305 | if (gridDimX.has_value()) |
306 | gpuLaunch.getGridSizeXMutable().assign(createConstValue(gridDimX.value())); |
307 | if (gridDimY.has_value()) |
308 | gpuLaunch.getGridSizeYMutable().assign(createConstValue(gridDimY.value())); |
309 | if (gridDimZ.has_value()) |
310 | gpuLaunch.getGridSizeZMutable().assign(createConstValue(gridDimZ.value())); |
311 | if (blockDimX.has_value()) |
312 | gpuLaunch.getBlockSizeXMutable().assign( |
313 | createConstValue(blockDimX.value())); |
314 | if (blockDimY.has_value()) |
315 | gpuLaunch.getBlockSizeYMutable().assign( |
316 | createConstValue(blockDimY.value())); |
317 | if (blockDimZ.has_value()) |
318 | gpuLaunch.getBlockSizeZMutable().assign( |
319 | createConstValue(blockDimZ.value())); |
320 | return DiagnosedSilenceableFailure::success(); |
321 | } |
322 | |
323 | } // namespace gpu |
324 | } // namespace transform |
325 | } // namespace mlir |
326 | |