//===- Utils.cpp - Utils for GPU transform ops ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/TransformOps/Utils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
using namespace mlir::transform::gpu;

#define DEBUG_TYPE "gpu-transforms"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")

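// For example, LDBG("foo") prints "[gpu-transforms] foo" when running with
// -debug-only=gpu-transforms (assertion-enabled builds only).
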
/// Return a flattened thread or block id for a workgroup/grid with the given
/// basis sizes, i.e. id.x + id.y * basis[0] + id.z * basis[0] * basis[1].
template <typename ThreadOrBlockIdOp>
static Value buildLinearId(RewriterBase &rewriter, Location loc,
                           ArrayRef<OpFoldResult> originalBasisOfr) {
  LLVM_DEBUG(llvm::interleaveComma(
                 originalBasisOfr,
                 DBGS() << "----buildLinearId with originalBasisOfr: ");
             llvm::dbgs() << "\n");
  assert(originalBasisOfr.size() == 3 && "expected 3 sizes");
  IndexType indexType = rewriter.getIndexType();
  AffineExpr tx, ty, tz, bdx, bdy;
  bindDims(rewriter.getContext(), tx, ty, tz);
  bindSymbols(rewriter.getContext(), bdx, bdy);
  SmallVector<OpFoldResult> vals{
      rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::x)
          .getResult(),
      rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::y)
          .getResult(),
      rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::z)
          .getResult(),
      originalBasisOfr[0], originalBasisOfr[1]};
  OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
      rewriter, loc, tx + ty * bdx + tz * bdx * bdy, vals);
  return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
}
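
// Illustrative example (not executed here): with a basis of
// (bdx, bdy, bdz) = (4, 8, 2) and a thread at (tx, ty, tz) = (1, 2, 1), the
// linear id is tx + ty * bdx + tz * bdx * bdy = 1 + 2 * 4 + 1 * 4 * 8 = 41.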

/// Create a linear id builder that takes the `originalBasisOfr` and decomposes
/// it in the basis of `forallMappingSizes`. The linear id builder returns an
/// n-D vector of ids for indexing and a 1-D size + id for predicate
/// generation.
template <typename ThreadOrBlockIdOp>
static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
  auto res = [multiplicity](RewriterBase &rewriter, Location loc,
                            ArrayRef<int64_t> forallMappingSizes,
                            ArrayRef<int64_t> originalBasis) {
    SmallVector<OpFoldResult> originalBasisOfr =
        getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
    OpFoldResult linearId =
        buildLinearId<ThreadOrBlockIdOp>(rewriter, loc, originalBasisOfr);
    // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
    // "row-major" order.
    SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
    SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
    OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
        rewriter, loc, d0.floorDiv(multiplicity), {linearId});
    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
    SmallVector<Value> ids;
    // Reverse back to be in [0 .. n] order.
    for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
      ids.push_back(
          affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId}));
    }

    // clang-format off
    LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes,
                                     DBGS() << "--delinearization basis: ");
               llvm::dbgs() << "\n";
               llvm::interleaveComma(strides,
                                     DBGS() << "--delinearization strides: ");
               llvm::dbgs() << "\n";
               llvm::interleaveComma(delinearizingExprs,
                                     DBGS() << "--delinearization exprs: ");
               llvm::dbgs() << "\n";
               llvm::interleaveComma(ids, DBGS() << "--ids: ");
               llvm::dbgs() << "\n";);
    // clang-format on

    // Return n-D ids for indexing and 1-D size + id for predicate generation.
    return IdBuilderResult{
        /*mappingIdOps=*/ids,
        /*availableMappingSizes=*/
        SmallVector<int64_t>{computeProduct(originalBasis)},
        // `forallMappingSizes` iterate in the scaled basis; scale them back
        // into the original basis to provide tight activeMappingSizes
        // quantities for predication.
        /*activeMappingSizes=*/
        SmallVector<int64_t>{computeProduct(forallMappingSizes) *
                             multiplicity},
        /*activeIdOps=*/SmallVector<Value>{linearId.get<Value>()}};
  };

  return res;
}
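
// Illustrative example (not executed here): with forallMappingSizes = {2, 3}
// and multiplicity = 1, the reversed sizes are {3, 2} and the strides are
// {2, 1}, so a linear id `l` delinearizes (after reversing back) into
// ids = {l mod 2, l floordiv 2}; e.g. l = 5 yields ids = {1, 2}.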

/// Create a simple 3-D id builder that takes the `originalBasisOfr`. The 3-D
/// id builder returns a 3-D vector of ids for indexing and 3-D sizes + ids for
/// predicate generation.
template <typename ThreadOrBlockIdOp>
static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
  auto res = [multiplicity](RewriterBase &rewriter, Location loc,
                            ArrayRef<int64_t> forallMappingSizes,
                            ArrayRef<int64_t> originalBasis) {
    IndexType indexType = rewriter.getIndexType();
    SmallVector<Value> ids{
        rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::x),
        rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::y),
        rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::z)};
    // In the 3-D mapping case, scale the first dimension by the multiplicity.
    SmallVector<Value> scaledIds = ids;
    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
    scaledIds[0] = affine::makeComposedFoldedAffineApply(
                       rewriter, loc, d0.floorDiv(multiplicity), {scaledIds[0]})
                       .get<Value>();
    // In the 3-D mapping case, scale the first size back into the original
    // basis by the multiplicity.
    SmallVector<int64_t> forallMappingSizeInOriginalBasis(
        forallMappingSizes.begin(), forallMappingSizes.end());
    forallMappingSizeInOriginalBasis[0] *= multiplicity;
    return IdBuilderResult{
        /*mappingIdOps=*/scaledIds,
        /*availableMappingSizes=*/SmallVector<int64_t>{originalBasis},
        // `forallMappingSizes` iterate in the scaled basis; scale them back
        // into the original basis to provide tight activeMappingSizes
        // quantities for predication.
        /*activeMappingSizes=*/
        SmallVector<int64_t>{forallMappingSizeInOriginalBasis},
        /*activeIdOps=*/ids};
  };
  return res;
}
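
// Illustrative example (not executed here): for a warp-level 3-D builder with
// multiplicity = 32, scaledIds[0] becomes threadIdx.x floordiv 32 (the warp id
// along x), and a forall mapping size of {2, 1, 1} warps is predicated against
// {64, 1, 1} threads in the original basis.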

namespace mlir {
namespace transform {
namespace gpu {

GpuIdBuilder::GpuIdBuilder(MLIRContext *ctx, bool useLinearMapping,
                           const MappingIdBuilderFnType &fn)
    : mappingAttributes(), idBuilder() {
  if (useLinearMapping) {
    for (uint64_t d = static_cast<uint64_t>(MappingId::LinearDim0),
                  e = getMaxEnumValForMappingId();
         d <= e; ++d)
      mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value()));
  } else {
    for (uint64_t d = static_cast<uint64_t>(MappingId::DimX),
                  e = static_cast<uint64_t>(MappingId::DimZ);
         d <= e; ++d)
      mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value()));
  }
}
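
// Note: with linear mapping the attribute pool is
// [LinearDim0 .. getMaxEnumValForMappingId()]; otherwise it is the 3-D pool
// [DimX, DimY, DimZ].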

GpuBlockIdBuilder::GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping)
    : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) {
        return GPUBlockMappingAttr::get(ctx, id);
      }) {
  idBuilder = useLinearMapping
                  ? commonLinearIdBuilderFn<BlockIdOp>(/*multiplicity=*/1)
                  : common3DIdBuilderFn<BlockIdOp>(/*multiplicity=*/1);
}

GpuWarpgroupIdBuilder::GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize,
                                             bool useLinearMapping)
    : GpuIdBuilder(ctx, useLinearMapping,
                   [](MLIRContext *ctx, MappingId id) {
                     return GPUWarpgroupMappingAttr::get(ctx, id);
                   }),
      warpSize(warpSize) {
  idBuilder = useLinearMapping
                  ? commonLinearIdBuilderFn<ThreadIdOp>(
                        /*multiplicity=*/kNumWarpsPerGroup * warpSize)
                  : common3DIdBuilderFn<ThreadIdOp>(
                        /*multiplicity=*/kNumWarpsPerGroup * warpSize);
}
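
// Note: a warpgroup spans kNumWarpsPerGroup consecutive warps, so its
// multiplicity is kNumWarpsPerGroup * warpSize threads (e.g. 4 * 32 = 128
// threads on NVIDIA hardware, assuming kNumWarpsPerGroup == 4).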

GpuWarpIdBuilder::GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize,
                                   bool useLinearMapping)
    : GpuIdBuilder(ctx, useLinearMapping,
                   [](MLIRContext *ctx, MappingId id) {
                     return GPUWarpMappingAttr::get(ctx, id);
                   }),
      warpSize(warpSize) {
  idBuilder =
      useLinearMapping
          ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize)
          : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize);
}

GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping)
    : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) {
        return GPUThreadMappingAttr::get(ctx, id);
      }) {
  idBuilder = useLinearMapping
                  ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1)
                  : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1);
}

DiagnosedSilenceableFailure checkGpuLimits(TransformOpInterface transformOp,
                                           std::optional<int64_t> gridDimX,
                                           std::optional<int64_t> gridDimY,
                                           std::optional<int64_t> gridDimZ,
                                           std::optional<int64_t> blockDimX,
                                           std::optional<int64_t> blockDimY,
                                           std::optional<int64_t> blockDimZ) {

  // TODO: pass a configuration object to set the limits properly.
  static constexpr int maxTotalBlockdim = 1024;
  static constexpr int maxBlockdimx = 1024;
  static constexpr int maxBlockdimy = 1024;
  static constexpr int maxBlockdimz = 64;
  static constexpr int maxTotalGriddim = 2147483647;
  static constexpr int maxGriddimx = 2147483647;
  static constexpr int maxGriddimy = 65535;
  static constexpr int maxGriddimz = 65535;
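  // NB: these defaults correspond to the limits documented for NVIDIA GPUs
  // (compute capability >= 3.0): at most 1024 threads per block (with z <= 64)
  // and a grid of at most (2^31 - 1) x 65535 x 65535 blocks.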

  if ((blockDimX.value_or(1) * blockDimY.value_or(1) * blockDimZ.value_or(1)) >
          maxTotalBlockdim ||
      (gridDimX.value_or(1) * gridDimY.value_or(1) * gridDimZ.value_or(1)) >
          maxTotalGriddim ||
      blockDimX.value_or(1) > maxBlockdimx ||
      blockDimY.value_or(1) > maxBlockdimy ||
      blockDimZ.value_or(1) > maxBlockdimz ||
      gridDimY.value_or(1) > maxGriddimy ||
      gridDimZ.value_or(1) > maxGriddimz ||
      gridDimX.value_or(1) > maxGriddimx) {
    return transformOp.emitSilenceableError()
           << "Trying to launch a GPU kernel with grid_dims = ("
           << gridDimX.value_or(1) << ", " << gridDimY.value_or(1) << ", "
           << gridDimZ.value_or(1) << ") block_dims = ("
           << blockDimX.value_or(1) << ", " << blockDimY.value_or(1) << ", "
           << blockDimZ.value_or(1) << "). It is larger than the limits.";
  }
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure createGpuLaunch(
    RewriterBase &rewriter, Location loc, TransformOpInterface transformOp,
    LaunchOp &launchOp, std::optional<int64_t> gridDimX,
    std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ,
    std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY,
    std::optional<int64_t> blockDimZ) {
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX,
                     blockDimY, blockDimZ);
  if (!diag.succeeded())
    return diag;

  auto createConst = [&](int dim) {
    return rewriter.create<arith::ConstantIndexOp>(loc, dim);
  };
  OpBuilder::InsertionGuard guard(rewriter);
  Value one = createConst(1);
  Value gridSizeX = gridDimX.has_value() ? createConst(gridDimX.value()) : one;
  Value gridSizeY = gridDimY.has_value() ? createConst(gridDimY.value()) : one;
  Value gridSizeZ = gridDimZ.has_value() ? createConst(gridDimZ.value()) : one;
  Value blkSizeX = blockDimX.has_value() ? createConst(blockDimX.value()) : one;
  Value blkSizeY = blockDimY.has_value() ? createConst(blockDimY.value()) : one;
  Value blkSizeZ = blockDimZ.has_value() ? createConst(blockDimZ.value()) : one;
  launchOp = rewriter.create<LaunchOp>(loc, gridSizeX, gridSizeY, gridSizeZ,
                                       blkSizeX, blkSizeY, blkSizeZ);
  rewriter.setInsertionPointToEnd(&launchOp.getBody().front());
  rewriter.create<TerminatorOp>(loc);
  return DiagnosedSilenceableFailure::success();
}
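
// Example usage (illustrative sketch; assumes `rewriter`, `loc` and
// `transformOp` are in scope):
//   LaunchOp launchOp;
//   DiagnosedSilenceableFailure diag =
//       createGpuLaunch(rewriter, loc, transformOp, launchOp,
//                       /*gridDimX=*/8, /*gridDimY=*/std::nullopt,
//                       /*gridDimZ=*/std::nullopt, /*blockDimX=*/128,
//                       /*blockDimY=*/std::nullopt,
//                       /*blockDimZ=*/std::nullopt);
//   if (!diag.succeeded())
//     return diag;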

/// Alter the launch configuration (grid and block sizes) of the given kernel
/// launch. Only the dimensions that are provided are overwritten; the others
/// keep their current values.
DiagnosedSilenceableFailure alterGpuLaunch(
    RewriterBase &rewriter, LaunchOp gpuLaunch,
    TransformOpInterface transformOp, std::optional<int64_t> gridDimX,
    std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ,
    std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY,
    std::optional<int64_t> blockDimZ) {
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX,
                     blockDimY, blockDimZ);
  if (!diag.succeeded())
    return diag;

  KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointAfterValue(currentBlockdim.x);
  auto createConstValue = [&](int dim) {
    return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(),
                                                   dim);
  };

  if (gridDimX.has_value())
    gpuLaunch.getGridSizeXMutable().assign(createConstValue(gridDimX.value()));
  if (gridDimY.has_value())
    gpuLaunch.getGridSizeYMutable().assign(createConstValue(gridDimY.value()));
  if (gridDimZ.has_value())
    gpuLaunch.getGridSizeZMutable().assign(createConstValue(gridDimZ.value()));
  if (blockDimX.has_value())
    gpuLaunch.getBlockSizeXMutable().assign(
        createConstValue(blockDimX.value()));
  if (blockDimY.has_value())
    gpuLaunch.getBlockSizeYMutable().assign(
        createConstValue(blockDimY.value()));
  if (blockDimZ.has_value())
    gpuLaunch.getBlockSizeZMutable().assign(
        createConstValue(blockDimZ.value()));
  return DiagnosedSilenceableFailure::success();
}

} // namespace gpu
} // namespace transform
} // namespace mlir