1 | //===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" |
10 | |
11 | #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" |
12 | #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" |
13 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
14 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
17 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
18 | #include "mlir/Dialect/GPU/TransformOps/Utils.h" |
19 | #include "mlir/Dialect/GPU/Transforms/Passes.h" |
20 | #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
21 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
22 | #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" |
23 | #include "mlir/Dialect/SCF/IR/SCF.h" |
24 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
25 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
26 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
27 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
28 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
29 | #include "mlir/IR/AffineExpr.h" |
30 | #include "mlir/IR/Builders.h" |
31 | #include "mlir/IR/BuiltinAttributes.h" |
32 | #include "mlir/IR/IRMapping.h" |
33 | #include "mlir/IR/MLIRContext.h" |
34 | #include "mlir/IR/OpDefinition.h" |
35 | #include "mlir/IR/Visitors.h" |
36 | #include "mlir/Support/LLVM.h" |
37 | #include "mlir/Transforms/DialectConversion.h" |
38 | #include "llvm/ADT/STLExtras.h" |
39 | #include "llvm/ADT/SmallVector.h" |
40 | #include "llvm/ADT/TypeSwitch.h" |
41 | #include "llvm/Support/Debug.h" |
42 | #include "llvm/Support/ErrorHandling.h" |
43 | #include <type_traits> |
44 | |
45 | using namespace mlir; |
46 | using namespace mlir::gpu; |
47 | using namespace mlir::transform; |
48 | using namespace mlir::transform::gpu; |
49 | |
50 | #define DEBUG_TYPE "gpu-transforms" |
51 | #define DEBUG_TYPE_ALIAS "gpu-transforms-alias" |
52 | |
53 | #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") |
54 | #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") |
55 | #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ") |
56 | |
57 | //===----------------------------------------------------------------------===// |
58 | // Apply...ConversionPatternsOp |
59 | //===----------------------------------------------------------------------===// |
60 | |
61 | void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns( |
62 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
63 | auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); |
64 | // NVVM uses alloca in the default address space to represent private |
65 | // memory allocations, so drop private annotations. NVVM uses address |
66 | // space 3 for shared memory. NVVM uses the default address space to |
67 | // represent global memory. |
  // Used in populateGpuToNVVMConversionPatterns, so attaching here for now.
69 | // TODO: We should have a single to_nvvm_type_converter. |
70 | populateGpuMemorySpaceAttributeConversions( |
71 | llvmTypeConverter, [](AddressSpace space) -> unsigned { |
72 | switch (space) { |
73 | case AddressSpace::Global: |
74 | return static_cast<unsigned>( |
75 | NVVM::NVVMMemorySpace::kGlobalMemorySpace); |
76 | case AddressSpace::Workgroup: |
77 | return static_cast<unsigned>( |
78 | NVVM::NVVMMemorySpace::kSharedMemorySpace); |
79 | case AddressSpace::Private: |
80 | return 0; |
81 | } |
        llvm_unreachable("unknown address space enum value");
83 | return 0; |
84 | }); |
85 | // Used in GPUToNVVM/WmmaOpsToNvvm.cpp so attaching here for now. |
86 | // TODO: We should have a single to_nvvm_type_converter. |
87 | llvmTypeConverter.addConversion( |
88 | [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); }); |
89 | populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns); |
90 | } |
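
// Illustrative note (the example types below are assumptions, not taken from
// this file): with the address-space conversion installed above, a payload
// type such as
//   memref<4xf32, #gpu.address_space<workgroup>>
// is converted so that its memory space becomes NVVM's shared space (3), while
// #gpu.address_space<private> maps to the default space (0) and
// #gpu.address_space<global> maps to NVVM's global space (1).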
91 | |
92 | LogicalResult |
93 | transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter( |
94 | transform::TypeConverterBuilderOpInterface builder) { |
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
97 | return success(); |
98 | } |
99 | |
100 | void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns( |
101 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
102 | auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); |
103 | populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns); |
104 | } |
105 | |
106 | LogicalResult |
107 | transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter( |
108 | transform::TypeConverterBuilderOpInterface builder) { |
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
111 | return success(); |
112 | } |
113 | |
114 | void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp:: |
115 | populatePatterns(TypeConverter &typeConverter, |
116 | RewritePatternSet &patterns) { |
117 | auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); |
118 | populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns); |
119 | } |
120 | |
121 | LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp:: |
122 | verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) { |
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
125 | return success(); |
126 | } |
127 | |
128 | //===----------------------------------------------------------------------===// |
129 | // Apply...PatternsOp |
//===----------------------------------------------------------------------===//
131 | |
132 | void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) { |
133 | populateGpuRewritePatterns(patterns); |
134 | } |
135 | |
136 | //===----------------------------------------------------------------------===// |
137 | // ApplyUnrollVectorsSubgroupMmaOp |
138 | //===----------------------------------------------------------------------===// |
139 | |
/// Pick an unrolling order that will allow the tensor core operations to reuse
/// the LHS register.
142 | static std::optional<SmallVector<int64_t>> |
143 | gpuMmaUnrollOrder(vector::ContractionOp contract) { |
144 | SmallVector<int64_t> order; |
145 | // First make reduction the outer dimensions. |
146 | for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) { |
147 | if (vector::isReductionIterator(iter)) { |
148 | order.push_back(index); |
149 | } |
150 | } |
151 | |
152 | llvm::SmallDenseSet<int64_t> dims; |
153 | for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) { |
154 | dims.insert(cast<AffineDimExpr>(expr).getPosition()); |
155 | } |
  // Then the parallel dimensions that are part of the LHS, as we want to
  // reuse the LHS register.
157 | for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) { |
158 | if (vector::isParallelIterator(iter) && dims.count(index)) { |
159 | order.push_back(index); |
160 | } |
161 | } |
162 | // Then the remaining parallel loops. |
163 | for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) { |
164 | if (vector::isParallelIterator(iter) && !dims.count(index)) { |
165 | order.push_back(index); |
166 | } |
167 | } |
168 | return order; |
169 | } |
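
// Worked example (assuming a canonical matmul-style contraction with
// iterator_types = [parallel, parallel, reduction] and an LHS indexing map
// over (m, k)): the reduction dimension k (index 2) comes first, then m
// (index 0) because it appears on the LHS, then n (index 1), i.e. the order
// [2, 0, 1]. Keeping the reduction outermost lets consecutive MMA operations
// reuse the same LHS fragment.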
170 | |
171 | /// Returns the target vector size for the target operation based on the native |
172 | /// vector size specified with `m`, `n`, and `k`. |
173 | static std::optional<SmallVector<int64_t>> |
174 | getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) { |
175 | if (auto contract = dyn_cast<vector::ContractionOp>(op)) { |
176 | int64_t contractRank = contract.getIteratorTypes().size(); |
177 | if (contractRank < 3) |
178 | return std::nullopt; |
179 | SmallVector<int64_t> nativeSize(contractRank - 3, 1); |
    nativeSize.append({m, n, k});
181 | return nativeSize; |
182 | } |
183 | if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { |
184 | int64_t writeRank = writeOp.getVectorType().getRank(); |
185 | if (writeRank < 2) |
186 | return std::nullopt; |
187 | SmallVector<int64_t> nativeSize(writeRank - 2, 1); |
    nativeSize.append({m, n});
189 | return nativeSize; |
190 | } |
191 | if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { |
192 | // Transfer read ops may need different shapes based on how they are being |
193 | // used. For simplicity just match the shape used by the extract strided op. |
194 | VectorType sliceType; |
195 | for (Operation *users : op->getUsers()) { |
      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
197 | if (!extract) |
198 | return std::nullopt; |
199 | auto vecType = cast<VectorType>(extract.getResult().getType()); |
200 | if (sliceType && sliceType != vecType) |
201 | return std::nullopt; |
202 | sliceType = vecType; |
203 | } |
204 | return llvm::to_vector(sliceType.getShape()); |
205 | } |
206 | if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) { |
207 | if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) { |
208 | // TODO: The condition for unrolling elementwise should be restricted |
209 | // only to operations that need unrolling (connected to the contract). |
210 | if (vecType.getRank() < 2) |
211 | return std::nullopt; |
212 | |
213 | // First check whether there is a slice to infer the shape from. This is |
214 | // required for cases where the accumulator type differs from the input |
215 | // types, in which case we will see an `arith.ext_` between the contract |
216 | // and transfer_read which needs to be unrolled. |
217 | VectorType sliceType; |
218 | for (Operation *users : op->getUsers()) { |
        auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
220 | if (!extract) |
221 | return std::nullopt; |
222 | auto vecType = cast<VectorType>(extract.getResult().getType()); |
223 | if (sliceType && sliceType != vecType) |
224 | return std::nullopt; |
225 | sliceType = vecType; |
226 | } |
227 | if (sliceType) |
228 | return llvm::to_vector(sliceType.getShape()); |
229 | |
230 | // Else unroll for trailing elementwise. |
231 | SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1); |
232 | // Map elementwise ops to the output shape. |
      nativeSize.append({m, n});
234 | return nativeSize; |
235 | } |
236 | } |
237 | return std::nullopt; |
238 | } |
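
// Worked example (sizes are illustrative): with m = 16, n = 16, k = 8, a
// rank-3 vector.contract unrolls to native tiles of [16, 16, 8], a rank-2
// vector.transfer_write to [16, 16], and a vector.transfer_read (or an
// elementwise op) inherits the slice shape of its
// vector.extract_strided_slice users when such users exist.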
239 | |
240 | void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns( |
241 | RewritePatternSet &patterns) { |
242 | auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> { |
243 | auto contract = dyn_cast<vector::ContractionOp>(op); |
244 | if (!contract) |
245 | return std::nullopt; |
246 | return gpuMmaUnrollOrder(contract); |
247 | }; |
248 | |
249 | int64_t m = getM(); |
250 | int64_t n = getN(); |
251 | int64_t k = getK(); |
252 | auto nativeShapeFn = |
253 | [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> { |
254 | return getSubgroupMmaNativeVectorSize(op, m, n, k); |
255 | }; |
256 | vector::populateVectorUnrollPatterns( |
257 | patterns, vector::UnrollVectorOptions() |
258 | .setNativeShapeFn(nativeShapeFn) |
259 | .setUnrollTraversalOrderFn(unrollOrder)); |
260 | } |
261 | |
262 | //===----------------------------------------------------------------------===// |
263 | // EliminateBarriersOp |
264 | //===----------------------------------------------------------------------===// |
265 | |
266 | void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) { |
267 | populateGpuEliminateBarriersPatterns(patterns); |
268 | } |
269 | |
270 | //===----------------------------------------------------------------------===// |
271 | // Block and thread mapping utilities. |
272 | //===----------------------------------------------------------------------===// |
273 | |
274 | namespace { |
275 | /// Local types used for mapping verification. |
276 | struct MappingKind {}; |
277 | struct BlockMappingKind : MappingKind {}; |
278 | struct ThreadMappingKind : MappingKind {}; |
279 | } // namespace |
280 | |
281 | static DiagnosedSilenceableFailure |
282 | definiteFailureHelper(std::optional<TransformOpInterface> transformOp, |
283 | Operation *target, const Twine &message) { |
284 | if (transformOp.has_value()) |
285 | return transformOp->emitDefiniteFailure() << message; |
  return emitDefiniteFailure(target, message);
287 | } |
288 | |
/// Check if the given mapping attributes are one of the desired attributes.
290 | template <typename MappingKindType> |
291 | static DiagnosedSilenceableFailure |
292 | checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp, |
293 | scf::ForallOp forallOp) { |
294 | if (!forallOp.getMapping().has_value()) { |
295 | return definiteFailureHelper(transformOp, forallOp, |
                                 "scf.forall op requires a mapping attribute");
297 | } |
298 | |
299 | bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(), |
300 | llvm::IsaPred<GPUBlockMappingAttr>); |
301 | bool hasWarpgroupMapping = llvm::any_of( |
302 | forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>); |
303 | bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(), |
304 | llvm::IsaPred<GPUWarpMappingAttr>); |
305 | bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(), |
306 | llvm::IsaPred<GPUThreadMappingAttr>); |
307 | int64_t countMappingTypes = 0; |
308 | countMappingTypes += hasBlockMapping ? 1 : 0; |
309 | countMappingTypes += hasWarpgroupMapping ? 1 : 0; |
310 | countMappingTypes += hasWarpMapping ? 1 : 0; |
311 | countMappingTypes += hasThreadMapping ? 1 : 0; |
312 | if (countMappingTypes > 1) { |
313 | return definiteFailureHelper( |
314 | transformOp, forallOp, |
        "cannot mix different mapping types, use nesting");
316 | } |
317 | if (std::is_same<MappingKindType, BlockMappingKind>::value && |
318 | !hasBlockMapping) { |
319 | return definiteFailureHelper( |
320 | transformOp, forallOp, |
        "scf.forall op requires a mapping attribute of kind 'block'");
322 | } |
323 | if (std::is_same<MappingKindType, ThreadMappingKind>::value && |
324 | !hasThreadMapping && !hasWarpMapping && !hasWarpgroupMapping) { |
325 | return definiteFailureHelper(transformOp, forallOp, |
326 | "scf.forall op requires a mapping attribute " |
                                 "of kind 'thread' or 'warp'");
328 | } |
329 | |
330 | DenseSet<Attribute> seen; |
331 | for (Attribute map : forallOp.getMapping()->getValue()) { |
332 | if (seen.contains(map)) { |
333 | return definiteFailureHelper( |
334 | transformOp, forallOp, |
335 | "duplicate attribute, cannot map different loops " |
          "to the same mapping id");
337 | } |
338 | seen.insert(map); |
339 | } |
340 | |
341 | auto isLinear = [](Attribute a) { |
342 | return cast<DeviceMappingAttrInterface>(a).isLinearMapping(); |
343 | }; |
344 | if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) && |
345 | !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) { |
346 | return definiteFailureHelper( |
347 | transformOp, forallOp, |
        "cannot mix linear and non-linear mapping modes");
349 | } |
350 | |
351 | return DiagnosedSilenceableFailure::success(); |
352 | } |
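
// Illustrative examples of what the checks above accept and reject (the
// attribute spellings are assumptions based on the GPU dialect mapping
// attributes):
//   mapping = [#gpu.thread<y>, #gpu.thread<x>]      // ok: one kind, no dups
//   mapping = [#gpu.block<x>, #gpu.thread<y>]       // rejected: mixed kinds
//   mapping = [#gpu.thread<x>, #gpu.thread<x>]      // rejected: duplicate id
//   mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<y>]
//                                   // rejected: mixes linear and non-linear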
353 | |
354 | template <typename MappingKindType> |
355 | static DiagnosedSilenceableFailure |
356 | verifyGpuMapping(std::optional<TransformOpInterface> transformOp, |
357 | scf::ForallOp forallOp) { |
358 | // Check the types of the mapping attributes match. |
359 | DiagnosedSilenceableFailure typeRes = |
360 | checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp); |
361 | if (!typeRes.succeeded()) |
362 | return typeRes; |
363 | |
  // Perform other non-type verifications.
365 | if (!forallOp.isNormalized()) |
366 | return definiteFailureHelper(transformOp, forallOp, |
                                 "unsupported non-normalized loops");
368 | if (forallOp.getNumResults() > 0) |
369 | return definiteFailureHelper(transformOp, forallOp, |
                                 "only bufferized scf.forall can be mapped");
371 | bool useLinearMapping = cast<DeviceMappingAttrInterface>( |
372 | forallOp.getMapping()->getValue().front()) |
373 | .isLinearMapping(); |
374 | // TODO: This would be more natural with support for Optional<EnumParameter> |
375 | // in GPUDeviceMappingAttr. |
376 | int64_t maxNumMappingsSupported = |
377 | useLinearMapping ? (getMaxEnumValForMappingId() - |
378 | static_cast<uint64_t>(MappingId::DimZ)) |
379 | : 3; |
380 | if (forallOp.getRank() > maxNumMappingsSupported) { |
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
385 | } |
386 | auto numParallelIterations = |
387 | getConstantIntValues(forallOp.getMixedUpperBound()); |
388 | if (!forallOp.isNormalized() || !numParallelIterations.has_value()) { |
389 | return definiteFailureHelper( |
390 | transformOp, forallOp, |
        "requires statically sized, normalized forall op");
392 | } |
393 | return DiagnosedSilenceableFailure::success(); |
394 | } |
395 | |
396 | /// Struct to return the result of the rewrite of a forall operation. |
397 | struct ForallRewriteResult { |
398 | SmallVector<int64_t> mappingSizes; |
399 | SmallVector<Value> mappingIds; |
400 | }; |
401 | |
402 | /// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR. |
403 | template <typename OpTy, typename OperationOrBlock> |
404 | static void |
405 | replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc, |
406 | OperationOrBlock *parent, Value replacement, |
407 | ArrayRef<int64_t> availableMappingSizes) { |
408 | parent->walk([&](OpTy idOp) { |
409 | if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1) |
410 | rewriter.replaceAllUsesWith(idOp.getResult(), replacement); |
411 | }); |
412 | } |
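
// Illustrative example (assumption): with availableMappingSizes = [8, 4, 1],
// every OpTy op querying the z dimension underneath `parent` has its result
// replaced by `replacement` (callers pass an early-created zero constant),
// which simplifies the IR for unit dimensions.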
413 | |
414 | static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( |
415 | RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp, |
416 | scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes, |
417 | ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) { |
  LDBG("--start rewriteOneForallCommonImpl");
419 | |
420 | // Step 1. Complete the mapping to a full mapping (with 1s) if necessary. |
421 | auto numParallelIterations = |
422 | getConstantIntValues(forallOp.getMixedUpperBound()); |
423 | assert(forallOp.isNormalized() && numParallelIterations.has_value() && |
         "requires statically sized, normalized forall op");
425 | SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value(); |
426 | SetVector<Attribute> forallMappingAttrs; |
427 | forallMappingAttrs.insert(forallOp.getMapping()->getValue().begin(), |
428 | forallOp.getMapping()->getValue().end()); |
429 | auto comparator = [](Attribute a, Attribute b) -> bool { |
430 | return cast<DeviceMappingAttrInterface>(a).getMappingId() < |
431 | cast<DeviceMappingAttrInterface>(b).getMappingId(); |
432 | }; |
433 | |
434 | // Step 1.b. In the linear case, compute the max mapping to avoid needlessly |
435 | // mapping all dimensions. In the 3-D mapping case we need to map all |
436 | // dimensions. |
437 | DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>( |
438 | *llvm::max_element(forallMappingAttrs, comparator)); |
439 | DeviceMappingAttrInterface maxLinearMapping; |
440 | if (maxMapping.isLinearMapping()) |
441 | maxLinearMapping = maxMapping; |
442 | for (auto attr : gpuIdBuilder.mappingAttributes) { |
443 | // If attr overflows, just skip. |
444 | if (maxLinearMapping && comparator(maxLinearMapping, attr)) |
445 | continue; |
446 | // Try to insert. If element was already present, just continue. |
447 | if (!forallMappingAttrs.insert(attr)) |
448 | continue; |
449 | // Otherwise, we have a new insertion without a size -> use size 1. |
450 | tmpMappingSizes.push_back(1); |
451 | } |
  LLVM_DEBUG(
      llvm::interleaveComma(
          tmpMappingSizes,
          DBGS() << "----tmpMappingSizes extracted from scf.forall op: ");
      llvm::dbgs() << "\n");
457 | |
  // Step 2. Sort the values by the corresponding DeviceMappingAttrInterface.
  SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
      forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 forallMappingAttrs, DBGS() << "----forallMappingAttrs: ");
             llvm::dbgs() << "\n");
466 | |
467 | // Step 3. Generate the mappingIdOps using the provided generator. |
468 | Location loc = forallOp.getLoc(); |
469 | OpBuilder::InsertionGuard guard(rewriter); |
470 | rewriter.setInsertionPoint(forallOp); |
471 | SmallVector<int64_t> originalBasis(availableMappingSizes); |
472 | bool originalBasisWasProvided = !originalBasis.empty(); |
473 | if (!originalBasisWasProvided) { |
474 | originalBasis = forallMappingSizes; |
475 | while (originalBasis.size() < 3) |
      originalBasis.push_back(1);
477 | } |
478 | |
479 | IdBuilderResult builderResult = |
480 | gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis); |
481 | |
482 | // Step 4. Map the induction variables to the mappingIdOps, this may involve |
483 | // a permutation. |
484 | SmallVector<Value> mappingIdOps = builderResult.mappingIdOps; |
485 | IRMapping bvm; |
486 | for (auto [iv, dim] : llvm::zip_equal( |
487 | forallOp.getInductionVars(), |
488 | forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) { |
489 | auto mappingAttr = cast<DeviceMappingAttrInterface>(dim); |
490 | Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()]; |
491 | bvm.map(iv, peIdOp); |
492 | } |
493 | |
494 | // Step 5. If the originalBasis is already known, create conditionals to |
495 | // predicate the region. Otherwise, the current forall determines the |
496 | // originalBasis and no predication occurs. |
497 | Value predicate; |
498 | if (originalBasisWasProvided) { |
499 | SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes; |
500 | SmallVector<int64_t> availableMappingSizes = |
501 | builderResult.availableMappingSizes; |
502 | SmallVector<Value> activeIdOps = builderResult.activeIdOps; |
503 | // clang-format off |
    LLVM_DEBUG(
        llvm::interleaveComma(
            activeMappingSizes, DBGS() << "----activeMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(
            availableMappingSizes, DBGS() << "----availableMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(activeIdOps, DBGS() << "----activeIdOps: ");
        llvm::dbgs() << "\n");
513 | // clang-format on |
514 | for (auto [activeId, activeMappingSize, availableMappingSize] : |
515 | llvm::zip_equal(activeIdOps, activeMappingSizes, |
516 | availableMappingSizes)) { |
517 | if (activeMappingSize > availableMappingSize) { |
        return definiteFailureHelper(
            transformOp, forallOp,
            "Trying to map to fewer GPU threads than loop iterations but "
            "overprovisioning is not yet supported. "
            "Try additional tiling before mapping or map to more threads.");
524 | } |
525 | if (activeMappingSize == availableMappingSize) |
526 | continue; |
527 | Value idx = |
528 | rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize); |
529 | Value tmpPredicate = rewriter.create<arith::CmpIOp>( |
530 | loc, arith::CmpIPredicate::ult, activeId, idx); |
531 | LDBG("----predicate: " << tmpPredicate); |
532 | predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate, |
533 | tmpPredicate) |
534 | : tmpPredicate; |
535 | } |
536 | } |
537 | |
538 | // Step 6. Move the body of forallOp. |
539 | // Erase the terminator first, it will not be used. |
  rewriter.eraseOp(forallOp.getTerminator());
541 | Block *targetBlock; |
542 | Block::iterator insertionPoint; |
543 | if (predicate) { |
    // Step 6.a. If predicated, move to the beginning.
545 | auto ifOp = rewriter.create<scf::IfOp>(loc, predicate, |
546 | /*withElseRegion=*/false); |
547 | targetBlock = ifOp.thenBlock(); |
548 | insertionPoint = ifOp.thenBlock()->begin(); |
549 | } else { |
550 | // Step 6.b. Otherwise, move inline just at the rewriter insertion |
551 | // point. |
552 | targetBlock = forallOp->getBlock(); |
553 | insertionPoint = rewriter.getInsertionPoint(); |
554 | } |
555 | Block &sourceBlock = forallOp.getRegion().front(); |
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());
558 | |
559 | // Step 7. RAUW indices. |
560 | for (Value loopIndex : forallOp.getInductionVars()) { |
561 | Value threadIdx = bvm.lookup(loopIndex); |
562 | rewriter.replaceAllUsesWith(loopIndex, threadIdx); |
563 | } |
564 | |
565 | // Step 8. Erase old op. |
  rewriter.eraseOp(forallOp);
567 | |
  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----result forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 mappingIdOps, DBGS() << "----result mappingIdOps: ");
             llvm::dbgs() << "\n");
573 | |
  result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
575 | return DiagnosedSilenceableFailure::success(); |
576 | } |
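
// Rough before/after sketch of the rewrite above (the MLIR below is
// illustrative and abridged, not verbatim):
//
//   scf.forall (%i, %j) in (7, 9) {
//     ... body using %i, %j ...
//   } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
//
// becomes, when an originalBasis larger than the trip counts forces
// predication:
//
//   %tidx = gpu.thread_id x
//   %tidy = gpu.thread_id y
//   %pred = arith.andi (%tidx < 9), (%tidy < 7)   // pseudo-notation
//   scf.if %pred {
//     ... body with %i replaced by %tidy and %j by %tidx ...
//   }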
577 | |
578 | //===----------------------------------------------------------------------===// |
579 | // MapForallToBlocks |
580 | //===----------------------------------------------------------------------===// |
581 | |
582 | DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl( |
583 | RewriterBase &rewriter, TransformOpInterface transformOp, |
584 | scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims, |
585 | const GpuIdBuilder &gpuIdBuilder) { |
  LDBG("Start mapForallToBlocksImpl");
587 | |
588 | { |
589 | // GPU-specific verifications. There is no better place to anchor |
590 | // those right now: the ForallOp is target-independent and the transform |
591 | // op does not apply to individual ForallOp. |
592 | DiagnosedSilenceableFailure diag = |
593 | verifyGpuMapping<BlockMappingKind>(transformOp, forallOp); |
594 | if (!diag.succeeded()) |
595 | return diag; |
596 | } |
597 | |
598 | Location loc = forallOp.getLoc(); |
599 | Block *parentBlock = forallOp->getBlock(); |
600 | Value zero; |
601 | { |
602 | // Create an early zero index value for replacements and immediately reset |
603 | // the insertion point. |
604 | OpBuilder::InsertionGuard guard(rewriter); |
605 | rewriter.setInsertionPointToStart(parentBlock); |
    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
607 | } |
608 | |
609 | ForallRewriteResult rewriteResult; |
610 | DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl( |
611 | rewriter, transformOp, forallOp, |
612 | /*availableMappingSizes=*/gridDims, rewriteResult, gpuIdBuilder); |
613 | |
  // Return if anything goes wrong; use silenceable failure as a match
  // failure.
616 | if (!diag.succeeded()) |
617 | return diag; |
618 | |
619 | // If gridDims was not provided already, set it from the return. |
620 | if (gridDims.empty()) { |
621 | gridDims = rewriteResult.mappingSizes; |
622 | while (gridDims.size() < 3) |
      gridDims.push_back(1);
624 | } |
  assert(gridDims.size() == 3 && "Need 3-D gridDims");
626 | |
627 | // Replace ids of dimensions known to be 1 by 0 to simplify the IR. |
628 | // Here, the result of mapping determines the available mapping sizes. |
629 | replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero, |
630 | rewriteResult.mappingSizes); |
631 | |
632 | return DiagnosedSilenceableFailure::success(); |
633 | } |
634 | |
635 | DiagnosedSilenceableFailure |
636 | mlir::transform::gpu::findTopLevelForallOp(Operation *target, |
637 | scf::ForallOp &topLevelForallOp, |
638 | TransformOpInterface transformOp) { |
639 | auto walkResult = target->walk([&](scf::ForallOp forallOp) { |
640 | if (forallOp->getParentOfType<scf::ForallOp>()) |
641 | return WalkResult::advance(); |
642 | if (topLevelForallOp) |
643 | // TODO: Handle multiple forall if they are independent. |
644 | return WalkResult::interrupt(); |
645 | topLevelForallOp = forallOp; |
646 | return WalkResult::advance(); |
647 | }); |
648 | |
649 | if (walkResult.wasInterrupted() || !topLevelForallOp) |
    return transformOp.emitSilenceableError()
           << "could not find a unique topLevel scf.forall";
652 | return DiagnosedSilenceableFailure::success(); |
653 | } |
654 | |
655 | DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne( |
656 | transform::TransformRewriter &rewriter, Operation *target, |
657 | ApplyToEachResultList &results, transform::TransformState &state) { |
658 | LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target); |
659 | auto transformOp = cast<TransformOpInterface>(getOperation()); |
660 | |
661 | if (!getGenerateGpuLaunch() && !gpuLaunch) { |
662 | DiagnosedSilenceableFailure diag = |
663 | emitSilenceableError() |
664 | << "Given target is not gpu.launch, set `generate_gpu_launch` " |
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
667 | return diag; |
668 | } |
669 | |
670 | scf::ForallOp topLevelForallOp; |
671 | DiagnosedSilenceableFailure diag = mlir::transform::gpu::findTopLevelForallOp( |
672 | target, topLevelForallOp, transformOp); |
673 | if (!diag.succeeded()) { |
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
675 | return diag; |
676 | } |
  assert(topLevelForallOp && "expect an scf.forall");
678 | |
679 | SmallVector<int64_t> gridDims{getGridDims()}; |
680 | if (!getGenerateGpuLaunch() && gridDims.size() != 3) |
    return transformOp.emitDefiniteFailure("transform requires size-3 mapping");
682 | |
683 | OpBuilder::InsertionGuard guard(rewriter); |
684 | rewriter.setInsertionPoint(topLevelForallOp); |
685 | |
686 | // Generate gpu launch here and move the forall inside |
687 | if (getGenerateGpuLaunch()) { |
688 | DiagnosedSilenceableFailure diag = |
689 | createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch); |
690 | if (!diag.succeeded()) |
691 | return diag; |
692 | |
693 | rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front()); |
694 | Operation *newForallOp = rewriter.clone(*topLevelForallOp); |
695 | rewriter.eraseOp(topLevelForallOp); |
696 | topLevelForallOp = cast<scf::ForallOp>(newForallOp); |
697 | } |
698 | |
699 | // The BlockIdBuilder adapts to whatever is thrown at it. |
700 | bool useLinearMapping = false; |
701 | if (topLevelForallOp.getMapping()) { |
702 | auto mappingAttr = cast<DeviceMappingAttrInterface>( |
703 | topLevelForallOp.getMapping()->getValue().front()); |
704 | useLinearMapping = mappingAttr.isLinearMapping(); |
705 | } |
706 | GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping); |
707 | |
708 | diag = mlir::transform::gpu::mapForallToBlocksImpl( |
709 | rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder); |
710 | if (!diag.succeeded()) |
711 | return diag; |
712 | |
  // Set the GPU launch configuration for the grid dims late; this is
  // subject to IR inspection.
715 | diag = alterGpuLaunch(rewriter, gpuLaunch, |
716 | cast<TransformOpInterface>(getOperation()), gridDims[0], |
717 | gridDims[1], gridDims[2]); |
718 | |
719 | results.push_back(gpuLaunch); |
720 | return diag; |
721 | } |
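
// Hedged usage sketch (the transform IR spelling is an assumption; consult the
// op's TableGen definition for the exact syntax):
//
//   %launch = transform.gpu.map_forall_to_blocks %func
//       generate_gpu_launch grid_dims = [4, 2, 1]
//       : (!transform.any_op) -> !transform.any_op
//
// With `generate_gpu_launch` set, a gpu.launch region is created around the
// top-level scf.forall before mapping; otherwise the target must already be a
// gpu.launch op and grid_dims must have size 3.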
722 | |
723 | LogicalResult transform::MapForallToBlocks::verify() { |
724 | if (!getGridDims().empty() && getGridDims().size() != 3) { |
    return emitOpError() << "transform requires empty or size-3 grid_dims";
726 | } |
727 | return success(); |
728 | } |
729 | |
730 | //===----------------------------------------------------------------------===// |
731 | // MapNestedForallToThreads |
732 | //===----------------------------------------------------------------------===// |
733 | |
734 | static DiagnosedSilenceableFailure checkMappingSpec( |
735 | std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp, |
736 | ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes, |
737 | int factor, bool useLinearMapping = false) { |
738 | if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) { |
739 | auto diag = definiteFailureHelper( |
740 | transformOp, forallOp, |
        Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
            std::to_string(factor));
743 | return diag; |
744 | } |
  if (computeProduct(numParallelIterations) * factor >
      computeProduct(blockOrGridSizes)) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("the number of required parallel resources (blocks or "
              "threads) ") +
            std::to_string(computeProduct(numParallelIterations) * factor) +
            std::string(" overflows the number of available resources ") +
            std::to_string(computeProduct(blockOrGridSizes)));
754 | return diag; |
755 | } |
756 | return DiagnosedSilenceableFailure::success(); |
757 | } |
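
// Worked example (numbers are illustrative): mapping a normalized 2x4
// scf.forall onto warps with warpSize = 32 requires 2 * 4 * 32 = 256 threads,
// so the provided block sizes must supply at least 256 threads in total and,
// in the non-linear (3-D) case, the leading size (threadIdx.x) must be a
// multiple of 32.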
758 | |
759 | static DiagnosedSilenceableFailure |
760 | getThreadIdBuilder(std::optional<TransformOpInterface> transformOp, |
761 | scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, |
762 | int64_t warpSize, GpuIdBuilder &gpuIdBuilder) { |
763 | auto mappingAttr = cast<DeviceMappingAttrInterface>( |
764 | forallOp.getMapping()->getValue().front()); |
765 | bool useLinearMapping = mappingAttr.isLinearMapping(); |
766 | |
767 | // Sanity checks that may result in runtime verification errors. |
768 | auto numParallelIterations = |
      getConstantIntValues(forallOp.getMixedUpperBound());
770 | if (!forallOp.isNormalized() || !numParallelIterations.has_value()) { |
771 | return definiteFailureHelper( |
772 | transformOp, forallOp, |
        "requires statically sized, normalized forall op");
774 | } |
775 | int64_t factor = 1; |
776 | if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) { |
777 | factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize; |
778 | } else if (isa<GPUWarpMappingAttr>(mappingAttr)) { |
779 | factor = warpSize; |
780 | } |
781 | DiagnosedSilenceableFailure diag = |
782 | checkMappingSpec(transformOp, forallOp, numParallelIterations.value(), |
783 | blockSizes, factor, useLinearMapping); |
784 | if (!diag.succeeded()) |
785 | return diag; |
786 | |
787 | // Start mapping. |
788 | MLIRContext *ctx = forallOp.getContext(); |
789 | gpuIdBuilder = |
790 | TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr) |
791 | .Case([&](GPUWarpgroupMappingAttr) { |
792 | return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping); |
793 | }) |
794 | .Case([&](GPUWarpMappingAttr) { |
795 | return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping); |
796 | }) |
797 | .Case([&](GPUThreadMappingAttr) { |
798 | return GpuThreadIdBuilder(ctx, useLinearMapping); |
799 | }) |
800 | .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder { |
            llvm_unreachable("unknown mapping attribute");
802 | }); |
803 | return DiagnosedSilenceableFailure::success(); |
804 | } |
805 | |
806 | DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl( |
807 | RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp, |
808 | scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize, |
809 | bool syncAfterDistribute) { |
810 | |
811 | { |
812 | // GPU-specific verifications. There is no better place to anchor |
813 | // those right now: the ForallOp is target-independent and the transform |
814 | // op does not apply to individual ForallOp. |
815 | DiagnosedSilenceableFailure diag = |
816 | verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp); |
817 | if (!diag.succeeded()) |
818 | return diag; |
819 | } |
820 | |
821 | GpuIdBuilder gpuIdBuilder; |
822 | { |
823 | // Try to construct the id builder, if it fails, return. |
824 | DiagnosedSilenceableFailure diag = getThreadIdBuilder( |
825 | transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder); |
826 | if (!diag.succeeded()) |
827 | return diag; |
828 | } |
829 | |
830 | Location loc = forallOp.getLoc(); |
831 | OpBuilder::InsertionGuard g(rewriter); |
832 | // Insert after to allow for syncthreads after `forall` is erased. |
833 | rewriter.setInsertionPointAfter(forallOp); |
834 | ForallRewriteResult rewriteResult; |
835 | DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl( |
836 | rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder); |
837 | if (!diag.succeeded()) |
838 | return diag; |
839 | // Add a syncthreads if needed. TODO: warpsync |
840 | if (syncAfterDistribute) |
841 | rewriter.create<BarrierOp>(loc); |
842 | |
843 | return DiagnosedSilenceableFailure::success(); |
844 | } |
845 | |
846 | DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl( |
847 | RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp, |
848 | Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize, |
849 | bool syncAfterDistribute) { |
  LDBG("Start mapNestedForallToThreadsImpl");
851 | if (blockDims.size() != 3) { |
    return definiteFailureHelper(transformOp, target,
                                 "requires size-3 thread mapping");
854 | } |
855 | |
856 | // Create an early zero index value for replacements. |
857 | Location loc = target->getLoc(); |
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
859 | DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success(); |
860 | WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) { |
861 | diag = mlir::transform::gpu::mapOneForallToThreadsImpl( |
862 | rewriter, transformOp, forallOp, blockDims, warpSize, |
863 | syncAfterDistribute); |
864 | if (diag.isDefiniteFailure()) |
865 | return WalkResult::interrupt(); |
866 | if (diag.succeeded()) |
867 | return WalkResult::skip(); |
868 | return WalkResult::advance(); |
869 | }); |
870 | if (walkResult.wasInterrupted()) |
871 | return diag; |
872 | |
873 | // Replace ids of dimensions known to be 1 by 0 to simplify the IR. |
874 | // Here, the result of mapping determines the available mapping sizes. |
875 | replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero, |
876 | blockDims); |
877 | |
878 | return DiagnosedSilenceableFailure::success(); |
879 | } |
880 | |
881 | DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne( |
882 | transform::TransformRewriter &rewriter, Operation *target, |
883 | ApplyToEachResultList &results, TransformState &state) { |
884 | LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target); |
885 | auto transformOp = cast<TransformOpInterface>(getOperation()); |
886 | |
887 | // Basic high-level verifications. |
888 | if (!gpuLaunch) |
    return emitSilenceableError() << "Given target is not a gpu.launch";
890 | |
891 | // Mapping to block ids. |
892 | SmallVector<int64_t> blockDims{getBlockDims()}; |
893 | DiagnosedSilenceableFailure diag = |
894 | checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt, |
895 | blockDims[0], blockDims[1], blockDims[2]); |
896 | if (diag.isSilenceableFailure()) { |
    diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
898 | return diag; |
899 | } |
900 | |
  // Set the GPU launch configuration for the block dims early; this is not
  // subject to IR inspection.
903 | diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt, |
904 | std::nullopt, std::nullopt, blockDims[0], blockDims[1], |
905 | blockDims[2]); |
906 | |
907 | rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front()); |
908 | diag = |
909 | mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims, |
910 | getWarpSize(), getSyncAfterDistribute()); |
911 | |
912 | results.push_back(gpuLaunch.getOperation()); |
913 | return diag; |
914 | } |
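
// Hedged usage sketch (the transform IR spelling is an assumption; consult the
// op's TableGen definition for the exact syntax):
//
//   transform.gpu.map_nested_forall_to_threads %launch
//       block_dims = [32, 4, 1]
//       : (!transform.any_op) -> !transform.any_op
//
// This rewrites every mapped scf.forall nested inside the gpu.launch body to
// use thread/warp/warpgroup ids and sets the launch's block dimensions.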
915 | |
916 | //===----------------------------------------------------------------------===// |
917 | // Transform op registration |
918 | //===----------------------------------------------------------------------===// |
919 | |
920 | namespace { |
/// Registers the GPU transform ops with the transform dialect and declares the
/// dialects (SCF, arith, GPU) that these transformations may generate.
923 | class GPUTransformDialectExtension |
924 | : public transform::TransformDialectExtension< |
925 | GPUTransformDialectExtension> { |
926 | public: |
927 | GPUTransformDialectExtension() { |
928 | declareGeneratedDialect<scf::SCFDialect>(); |
929 | declareGeneratedDialect<arith::ArithDialect>(); |
930 | declareGeneratedDialect<GPUDialect>(); |
931 | registerTransformOps< |
932 | #define GET_OP_LIST |
933 | #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc" |
934 | >(); |
935 | } |
936 | }; |
937 | } // namespace |
938 | |
939 | #define GET_OP_CLASSES |
940 | #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc" |
941 | |
942 | void mlir::gpu::registerTransformDialectExtension(DialectRegistry ®istry) { |
943 | registry.addExtensions<GPUTransformDialectExtension>(); |
944 | } |
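
// Minimal C++ usage sketch (registration flow only; names other than the
// function defined above are standard MLIR APIs):
//
//   mlir::DialectRegistry registry;
//   mlir::gpu::registerTransformDialectExtension(registry);
//   mlir::MLIRContext context(registry);
//
// The extension must be registered before the transform ops are parsed or
// constructed in that context.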
945 | |