//===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
using namespace mlir::transform::gpu;

#define DEBUG_TYPE "gpu-transforms"
#define DEBUG_TYPE_ALIAS "gpu-transforms-alias"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")

//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  // Used in populateGpuToNVVMConversionPatterns so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](AddressSpace space) -> unsigned {
        switch (space) {
        case AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Used in GPUToNVVM/WmmaOpsToNvvm.cpp so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  llvmTypeConverter.addConversion(
      [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
  // Set higher benefit, so patterns will run before generic LLVM lowering.
  populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns,
                                      getBenefit());
}

LogicalResult
transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}
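
// A rough usage sketch (illustrative only; the surrounding script, the type
// converter block, and any legality attributes are assumptions, not copied
// from a test): this op is intended to be listed inside a
// `transform.apply_conversion_patterns` region whose type converter is an
// LLVMTypeConverter, which is exactly what verifyTypeConverter enforces.
//
//   transform.apply_conversion_patterns to %module {
//     transform.apply_conversion_patterns.gpu.gpu_to_nvvm
//   } with type_converter {
//     // ... an op that builds an LLVMTypeConverter ...
//   } : !transform.any_op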

void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    populatePatterns(TypeConverter &typeConverter,
                     RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns);
}

LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuRewritePatterns(patterns);
}

void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// ApplyUnrollVectorsSubgroupMmaOp
//===----------------------------------------------------------------------===//

/// Pick an unrolling order that will allow tensorcore operation to reuse LHS
/// register.
static std::optional<SmallVector<int64_t>>
gpuMmaUnrollOrder(vector::ContractionOp contract) {
  SmallVector<int64_t> order;
  // First make reduction the outer dimensions.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isReductionIterator(iter)) {
      order.push_back(index);
    }
  }

  llvm::SmallDenseSet<int64_t> dims;
  for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
    dims.insert(cast<AffineDimExpr>(expr).getPosition());
  }
  // Then parallel dimensions that are part of Lhs as we want to re-use Lhs.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && dims.count(index)) {
      order.push_back(index);
    }
  }
  // Then the remaining parallel loops.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && !dims.count(index)) {
      order.push_back(index);
    }
  }
  return order;
}
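
// A hedged worked example (the matmul below is hypothetical, not taken from a
// test): for a vector.contract with indexing maps LHS = (m, k), RHS = (k, n),
// ACC = (m, n) and iterator types [parallel (m), parallel (n), reduction (k)],
// gpuMmaUnrollOrder returns [2, 0, 1]: the reduction dimension k comes first,
// then m because it appears in the LHS map, then the remaining parallel
// dimension n. Since n varies fastest in this traversal, successive unrolled
// tiles share the same (m, k) LHS fragment, which is the intended register
// reuse.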

/// Returns the target vector size for the target operation based on the native
/// vector size specified with `m`, `n`, and `k`.
static std::optional<SmallVector<int64_t>>
getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
    int64_t contractRank = contract.getIteratorTypes().size();
    if (contractRank < 3)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(contractRank - 3, 1);
    nativeSize.append({m, n, k});
    return nativeSize;
  }
  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    int64_t writeRank = writeOp.getVectorType().getRank();
    if (writeRank < 2)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(writeRank - 2, 1);
    nativeSize.append({m, n});
    return nativeSize;
  }
  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
    // Transfer read ops may need different shapes based on how they are being
    // used. For simplicity just match the shape used by the extract strided op.
    VectorType sliceType;
    for (Operation *users : op->getUsers()) {
      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
      if (!extract)
        return std::nullopt;
      auto vecType = cast<VectorType>(extract.getResult().getType());
      if (sliceType && sliceType != vecType)
        return std::nullopt;
      sliceType = vecType;
    }
    return llvm::to_vector(sliceType.getShape());
  }
  if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) {
    if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
      // TODO: The condition for unrolling elementwise should be restricted
      // only to operations that need unrolling (connected to the contract).
      if (vecType.getRank() < 2)
        return std::nullopt;

      // First check whether there is a slice to infer the shape from. This is
      // required for cases where the accumulator type differs from the input
      // types, in which case we will see an `arith.ext_` between the contract
      // and transfer_read which needs to be unrolled.
      VectorType sliceType;
      for (Operation *users : op->getUsers()) {
        auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
        if (!extract)
          return std::nullopt;
        auto vecType = cast<VectorType>(extract.getResult().getType());
        if (sliceType && sliceType != vecType)
          return std::nullopt;
        sliceType = vecType;
      }
      if (sliceType)
        return llvm::to_vector(sliceType.getShape());

      // Else unroll for trailing elementwise.
      SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
      // Map elementwise ops to the output shape.
      nativeSize.append({m, n});
      return nativeSize;
    }
  }
  return std::nullopt;
}
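
// A hedged worked example (shapes are hypothetical): with m = 16, n = 16,
// k = 8, a rank-3 vector.contract gets native size [16, 16, 8], a rank-4
// (batched) contract gets [1, 16, 16, 8], and a vector.transfer_write of a
// rank-2 vector gets [16, 16]. transfer_read and elementwise producers instead
// copy the shape of the vector.extract_strided_slice that consumes them, so
// producers are unrolled consistently with their uses.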

void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
    RewritePatternSet &patterns) {
  auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
    auto contract = dyn_cast<vector::ContractionOp>(op);
    if (!contract)
      return std::nullopt;
    return gpuMmaUnrollOrder(contract);
  };

  int64_t m = getM();
  int64_t n = getN();
  int64_t k = getK();
  auto nativeShapeFn =
      [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> {
    return getSubgroupMmaNativeVectorSize(op, m, n, k);
  };
  vector::populateVectorUnrollPatterns(
      patterns, vector::UnrollVectorOptions()
                    .setNativeShapeFn(nativeShapeFn)
                    .setUnrollTraversalOrderFn(unrollOrder));
}

//===----------------------------------------------------------------------===//
// EliminateBarriersOp
//===----------------------------------------------------------------------===//

void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuEliminateBarriersPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Block and thread mapping utilities.
//===----------------------------------------------------------------------===//

namespace {
/// Local types used for mapping verification.
struct MappingKind {};
struct BlockMappingKind : MappingKind {};
struct ThreadMappingKind : MappingKind {};
} // namespace

static DiagnosedSilenceableFailure
definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
                      Operation *target, const Twine &message) {
  if (transformOp.has_value())
    return transformOp->emitDefiniteFailure() << message;
  return emitDefiniteFailure(target, message);
}

/// Check if given mapping attributes are one of the desired attributes
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  if (!forallOp.getMapping().has_value()) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute");
  }

  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
                                      llvm::IsaPred<GPUBlockMappingAttr>);
  bool hasWarpgroupMapping = llvm::any_of(
      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPUWarpMappingAttr>);
  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
                                       llvm::IsaPred<GPUThreadMappingAttr>);
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }
  if (std::is_same<MappingKindType, BlockMappingKind>::value &&
      !hasBlockMapping) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "scf.forall op requires a mapping attribute of kind 'block'");
  }
  if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
      !hasThreadMapping && !hasWarpMapping && !hasWarpgroupMapping) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute "
                                 "of kind 'thread' or 'warp'");
  }

  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicate attribute, cannot map different loops "
          "to the same mapping id");
    }
    seen.insert(map);
  }

  auto isLinear = [](Attribute a) {
    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
  };
  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix linear and non-linear mapping modes");
  }

  return DiagnosedSilenceableFailure::success();
}
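
// A hedged illustration of what the check above accepts and rejects (the
// attribute spellings are the usual #gpu mapping attributes; the loops they
// would decorate are hypothetical):
//   mapping = [#gpu.thread<y>, #gpu.thread<x>]            -> accepted
//   mapping = [#gpu.block<x>, #gpu.thread<y>]             -> rejected (mixed kinds)
//   mapping = [#gpu.thread<x>, #gpu.thread<x>]            -> rejected (duplicate id)
//   mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<x>] -> rejected (mixing
//                                                            linear and 3-D modes)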

template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  // Check the types of the mapping attributes match.
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  // Perform other non-types verifications.
  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
                              forallOp.getMapping()->getValue().front())
                              .isLinearMapping();
  // TODO: This would be more natural with support for Optional<EnumParameter>
  // in GPUDeviceMappingAttr.
  int64_t maxNumMappingsSupported =
      useLinearMapping ? (getMaxEnumValForMappingId() -
                          static_cast<uint64_t>(MappingId::DimZ))
                       : 3;
  if (forallOp.getRank() > maxNumMappingsSupported) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
  }
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  return DiagnosedSilenceableFailure::success();
}

/// Struct to return the result of the rewrite of a forall operation.
struct ForallRewriteResult {
  SmallVector<int64_t> mappingSizes;
  SmallVector<Value> mappingIds;
};

/// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR.
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
                            OperationOrBlock *parent, Value replacement,
                            ArrayRef<int64_t> availableMappingSizes) {
  parent->walk([&](OpTy idOp) {
    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
  });
}
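
// A minimal sketch of the helper's effect (sizes are hypothetical): with
// availableMappingSizes = [8, 4, 1], every OpTy id op for the z dimension
// inside `parent` (e.g. gpu.block_id z or gpu.thread_id z, depending on OpTy)
// can only ever be 0, so all of its uses are rewritten to the provided
// `replacement` constant and the id op becomes trivially dead.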

static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
    ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
  LDBG("--start rewriteOneForallCommonImpl");

  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
         "requires statically sized, normalized forall op");
  SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
  SetVector<Attribute> forallMappingAttrs;
  forallMappingAttrs.insert_range(forallOp.getMapping()->getValue());
  auto comparator = [](Attribute a, Attribute b) -> bool {
    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
           cast<DeviceMappingAttrInterface>(b).getMappingId();
  };

  // Step 1.b. In the linear case, compute the max mapping to avoid needlessly
  // mapping all dimensions. In the 3-D mapping case we need to map all
  // dimensions.
  DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>(
      *llvm::max_element(forallMappingAttrs, comparator));
  DeviceMappingAttrInterface maxLinearMapping;
  if (maxMapping.isLinearMapping())
    maxLinearMapping = maxMapping;
  for (auto attr : gpuIdBuilder.mappingAttributes) {
    // If attr overflows, just skip.
    if (maxLinearMapping && comparator(maxLinearMapping, attr))
      continue;
    // Try to insert. If element was already present, just continue.
    if (!forallMappingAttrs.insert(attr))
      continue;
    // Otherwise, we have a new insertion without a size -> use size 1.
    tmpMappingSizes.push_back(1);
  }
  LDBG("----tmpMappingSizes extracted from scf.forall op: "
       << llvm::interleaved(tmpMappingSizes));

  // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
  SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
      forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
  LDBG("----forallMappingSizes: " << llvm::interleaved(forallMappingSizes));
  LDBG("----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs));

  // Step 3. Generate the mappingIdOps using the provided generator.
  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forallOp);
  SmallVector<int64_t> originalBasis(availableMappingSizes);
  bool originalBasisWasProvided = !originalBasis.empty();
  if (!originalBasisWasProvided) {
    originalBasis = forallMappingSizes;
    while (originalBasis.size() < 3)
      originalBasis.push_back(1);
  }

  IdBuilderResult builderResult =
      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);

  // Step 4. Map the induction variables to the mappingIdOps, this may involve
  // a permutation.
  SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
  IRMapping bvm;
  for (auto [iv, dim] : llvm::zip_equal(
           forallOp.getInductionVars(),
           forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
    Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
    bvm.map(iv, peIdOp);
  }

  // Step 5. If the originalBasis is already known, create conditionals to
  // predicate the region. Otherwise, the current forall determines the
  // originalBasis and no predication occurs.
  Value predicate;
  if (originalBasisWasProvided) {
    SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes;
    SmallVector<int64_t> availableMappingSizes =
        builderResult.availableMappingSizes;
    SmallVector<Value> activeIdOps = builderResult.activeIdOps;
    LDBG("----activeMappingSizes: " << llvm::interleaved(activeMappingSizes));
    LDBG("----availableMappingSizes: "
         << llvm::interleaved(availableMappingSizes));
    LDBG("----activeIdOps: " << llvm::interleaved(activeIdOps));
    for (auto [activeId, activeMappingSize, availableMappingSize] :
         llvm::zip_equal(activeIdOps, activeMappingSizes,
                         availableMappingSizes)) {
      if (activeMappingSize > availableMappingSize) {
        return definiteFailureHelper(
            transformOp, forallOp,
            "Trying to map to fewer GPU threads than loop iterations but "
            "overprovisioning is not yet supported. "
517 | "Try additional tiling of the before mapping or map to more " |
518 | "threads."); |
519 | } |
520 | if (activeMappingSize == availableMappingSize) |
521 | continue; |
522 | Value idx = |
523 | rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize); |
524 | Value tmpPredicate = rewriter.create<arith::CmpIOp>( |
525 | loc, arith::CmpIPredicate::ult, activeId, idx); |
526 | LDBG("----predicate: "<< tmpPredicate); |
527 | predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate, |
528 | tmpPredicate) |
529 | : tmpPredicate; |
530 | } |
531 | } |
532 | |
533 | // Step 6. Move the body of forallOp. |
534 | // Erase the terminator first, it will not be used. |
535 | rewriter.eraseOp(op: forallOp.getTerminator()); |
536 | Block *targetBlock; |
537 | Block::iterator insertionPoint; |
538 | if (predicate) { |
539 | // Step 6.a. If predicated, move at the beginning. |
540 | auto ifOp = rewriter.create<scf::IfOp>(loc, predicate, |
541 | /*withElseRegion=*/false); |
542 | targetBlock = ifOp.thenBlock(); |
543 | insertionPoint = ifOp.thenBlock()->begin(); |
544 | } else { |
545 | // Step 6.b. Otherwise, move inline just at the rewriter insertion |
546 | // point. |
547 | targetBlock = forallOp->getBlock(); |
548 | insertionPoint = rewriter.getInsertionPoint(); |
549 | } |
550 | Block &sourceBlock = forallOp.getRegion().front(); |
551 | targetBlock->getOperations().splice(where: insertionPoint, |
552 | L2&: sourceBlock.getOperations()); |
553 | |
554 | // Step 7. RAUW indices. |
555 | for (Value loopIndex : forallOp.getInductionVars()) { |
556 | Value threadIdx = bvm.lookup(loopIndex); |
557 | rewriter.replaceAllUsesWith(loopIndex, threadIdx); |
558 | } |
559 | |
560 | // Step 8. Erase old op. |
561 | rewriter.eraseOp(op: forallOp); |
562 | |
563 | LDBG("----result forallMappingSizes: " |
564 | << llvm::interleaved(forallMappingSizes)); |
565 | LDBG("----result mappingIdOps: "<< llvm::interleaved(mappingIdOps)); |
566 | |
567 | result = ForallRewriteResult{.mappingSizes: forallMappingSizes, .mappingIds: mappingIdOps}; |
568 | return DiagnosedSilenceableFailure::success(); |
569 | } |
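
// A hedged end-to-end illustration of the rewrite above (all sizes are
// hypothetical): an scf.forall over (5, 4) iterations mapped to threads with an
// available basis of [8, 4, 1] gets its induction variables replaced by thread
// ids; because only 5 of the 8 x-threads are active, the moved body is wrapped
// in an scf.if guarded by `arith.cmpi ult, %tid_x, %c5`, while the y dimension
// matches its available size exactly and needs no predicate.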

//===----------------------------------------------------------------------===//
// MapForallToBlocks
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
    RewriterBase &rewriter, TransformOpInterface transformOp,
    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
    const GpuIdBuilder &gpuIdBuilder) {
  LDBG("Start mapForallToBlocksImpl");

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<BlockMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  Block *parentBlock = forallOp->getBlock();
  Value zero;
  {
    // Create an early zero index value for replacements and immediately reset
    // the insertion point.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(parentBlock);
    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  }

  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp,
      /*availableMappingSizes=*/gridDims, rewriteResult, gpuIdBuilder);

  // Return if anything goes wrong, use silenceable failure as a match
  // failure.
  if (!diag.succeeded())
    return diag;

  // If gridDims was not provided already, set it from the return.
  if (gridDims.empty()) {
    gridDims = rewriteResult.mappingSizes;
    while (gridDims.size() < 3)
      gridDims.push_back(1);
  }
  assert(gridDims.size() == 3 && "Need 3-D gridDims");

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
                                          rewriteResult.mappingSizes);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
mlir::transform::gpu::findTopLevelForallOp(Operation *target,
                                           scf::ForallOp &topLevelForallOp,
                                           TransformOpInterface transformOp) {
  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
    if (forallOp->getParentOfType<scf::ForallOp>())
      return WalkResult::advance();
    if (topLevelForallOp)
      // TODO: Handle multiple forall if they are independent.
      return WalkResult::interrupt();
    topLevelForallOp = forallOp;
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted() || !topLevelForallOp)
    return transformOp.emitSilenceableError()
           << "could not find a unique topLevel scf.forall";
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag = mlir::transform::gpu::findTopLevelForallOp(
      target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  assert(topLevelForallOp && "expect an scf.forall");

  SmallVector<int64_t> gridDims{getGridDims()};
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure("transform requires size-3 mapping");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(topLevelForallOp);

  // Generate gpu launch here and move the forall inside
  if (getGenerateGpuLaunch()) {
    DiagnosedSilenceableFailure diag =
        createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded())
      return diag;

    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForallOp = rewriter.clone(*topLevelForallOp);
    rewriter.eraseOp(topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(newForallOp);
  }

  // The BlockIdBuilder adapts to whatever is thrown at it.
  bool useLinearMapping = false;
  if (topLevelForallOp.getMapping()) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(
        topLevelForallOp.getMapping()->getValue().front());
    useLinearMapping = mappingAttr.isLinearMapping();
  }
  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);

  diag = mlir::transform::gpu::mapForallToBlocksImpl(
      rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Set the GPU launch configuration for the grid dims late, this is
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        cast<TransformOpInterface>(getOperation()), gridDims[0],
                        gridDims[1], gridDims[2]);

  results.push_back(gpuLaunch);
  return diag;
}

LogicalResult transform::MapForallToBlocks::verify() {
  if (!getGridDims().empty() && getGridDims().size() != 3) {
    return emitOpError() << "transform requires empty or size-3 grid_dims";
  }
  return success();
}
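
// A rough usage sketch from a transform script (illustrative only; the exact
// assembly format and result types follow the op's TableGen declaration):
//
//   %launch = transform.gpu.map_forall_to_blocks %func
//       generate_gpu_launch grid_dims = [16, 8, 1]
//       : (!transform.any_op) -> !transform.any_op
//
// With `generate_gpu_launch`, the op wraps the top-level scf.forall in a fresh
// gpu.launch; otherwise the target is expected to already be a gpu.launch.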

//===----------------------------------------------------------------------===//
// MapNestedForallToThreads
//===----------------------------------------------------------------------===//

static DiagnosedSilenceableFailure checkMappingSpec(
    std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
    ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
    int factor, bool useLinearMapping = false) {
  if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
            Twine(factor));
    return diag;
  }
  if (computeProduct(numParallelIterations) * factor >
      computeProduct(blockOrGridSizes)) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("the number of required parallel resources (blocks or "
              "threads) ") +
            Twine(computeProduct(numParallelIterations) * factor) +
            " overflows the number of available resources " +
            Twine(computeProduct(blockOrGridSizes)));
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}
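
// A hedged numeric example for the check above (all sizes are hypothetical): a
// warp-mapped forall over (2, 3) iterations with warpSize = 32 requires
// 2 * 3 * 32 = 192 threads. blockSizes = [64, 2, 1] provides only 128 and is
// rejected; blockSizes = [256, 1, 1] passes both the divisibility check
// (256 % 32 == 0) and the resource check (256 >= 192).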

static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                   scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                   int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
  auto mappingAttr = cast<DeviceMappingAttrInterface>(
      forallOp.getMapping()->getValue().front());
  bool useLinearMapping = mappingAttr.isLinearMapping();

  // Sanity checks that may result in runtime verification errors.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  int64_t factor = 1;
  if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
  } else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
    factor = warpSize;
  }
  DiagnosedSilenceableFailure diag =
      checkMappingSpec(transformOp, forallOp, numParallelIterations.value(),
                       blockSizes, factor, useLinearMapping);
  if (!diag.succeeded())
    return diag;

  // Start mapping.
  MLIRContext *ctx = forallOp.getContext();
  gpuIdBuilder =
      TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
          .Case([&](GPUWarpgroupMappingAttr) {
            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUWarpMappingAttr) {
            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUThreadMappingAttr) {
            return GpuThreadIdBuilder(ctx, useLinearMapping);
          })
          .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
            llvm_unreachable("unknown mapping attribute");
          });
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
    bool syncAfterDistribute) {

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  GpuIdBuilder gpuIdBuilder;
  {
    // Try to construct the id builder, if it fails, return.
    DiagnosedSilenceableFailure diag = getThreadIdBuilder(
        transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  // Insert after to allow for syncthreads after `forall` is erased.
  rewriter.setInsertionPointAfter(forallOp);
  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder);
  if (!diag.succeeded())
    return diag;
  // Add a syncthreads if needed. TODO: warpsync
  if (syncAfterDistribute)
    rewriter.create<BarrierOp>(loc);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
    bool syncAfterDistribute) {
  LDBG("Start mapNestedForallToThreadsImpl");
  if (blockDims.size() != 3) {
    return definiteFailureHelper(transformOp, target,
                                 "requires size-3 thread mapping");
  }

  // Create an early zero index value for replacements.
  Location loc = target->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
        rewriter, transformOp, forallOp, blockDims, warpSize,
        syncAfterDistribute);
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();
    if (diag.succeeded())
      return WalkResult::skip();
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return diag;

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
                                          blockDims);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  // Basic high-level verifications.
  if (!gpuLaunch)
    return emitSilenceableError() << "Given target is not a gpu.launch";

  // Mapping to block ids.
  SmallVector<int64_t> blockDims{getBlockDims()};
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                     blockDims[0], blockDims[1], blockDims[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
    return diag;
  }

  // Set the GPU launch configuration for the block dims early, this is not
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
                        std::nullopt, std::nullopt, blockDims[0], blockDims[1],
                        blockDims[2]);

  rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
  diag =
      mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
                                   getWarpSize(), getSyncAfterDistribute());

  results.push_back(gpuLaunch.getOperation());
  return diag;
}
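
// A rough usage sketch from a transform script (illustrative only; the exact
// assembly and optional attributes such as `warp_size` or
// `sync_after_distribute` follow the op's TableGen definition):
//
//   transform.gpu.map_nested_forall_to_threads %launch
//       block_dims = [32, 4, 1]
//       : (!transform.any_op) -> !transform.any_op
//
// The target must already be a gpu.launch, typically the result of
// transform.gpu.map_forall_to_blocks above.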

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
/// Registers new ops and declares PDL as dependent dialect since the
/// additional ops are using PDL types for operands and results.
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)

  GPUTransformDialectExtension() {
    declareGeneratedDialect<GPUDialect>();
    declareGeneratedDialect<amdgpu::AMDGPUDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<scf::SCFDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"

void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}