1//===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
10
11#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
13#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
14#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/GPU/IR/GPUDialect.h"
18#include "mlir/Dialect/GPU/TransformOps/Utils.h"
19#include "mlir/Dialect/GPU/Transforms/Passes.h"
20#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
21#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
22#include "mlir/Dialect/MemRef/IR/MemRef.h"
23#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/Dialect/Transform/IR/TransformDialect.h"
26#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
27#include "mlir/Dialect/Utils/IndexingUtils.h"
28#include "mlir/Dialect/Vector/IR/VectorOps.h"
29#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
30#include "mlir/IR/AffineExpr.h"
31#include "mlir/IR/Builders.h"
32#include "mlir/IR/BuiltinAttributes.h"
33#include "mlir/IR/IRMapping.h"
34#include "mlir/IR/MLIRContext.h"
35#include "mlir/IR/OpDefinition.h"
36#include "mlir/IR/Visitors.h"
37#include "mlir/Support/LLVM.h"
38#include "mlir/Transforms/DialectConversion.h"
39#include "llvm/ADT/STLExtras.h"
40#include "llvm/ADT/SmallVector.h"
41#include "llvm/ADT/TypeSwitch.h"
42#include "llvm/Support/Debug.h"
43#include "llvm/Support/ErrorHandling.h"
44#include "llvm/Support/InterleavedRange.h"
45#include "llvm/Support/LogicalResult.h"
46#include <type_traits>
47
48using namespace mlir;
49using namespace mlir::gpu;
50using namespace mlir::transform;
51using namespace mlir::transform::gpu;
52
53#define DEBUG_TYPE "gpu-transforms"
54#define DEBUG_TYPE_ALIAS "gpu-transforms-alias"
55
56#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
57#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
58#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
59
60//===----------------------------------------------------------------------===//
61// Apply...ConversionPatternsOp
62//===----------------------------------------------------------------------===//
63
64void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
65 TypeConverter &typeConverter, RewritePatternSet &patterns) {
66 auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
67 // NVVM uses alloca in the default address space to represent private
68 // memory allocations, so drop private annotations. NVVM uses address
69 // space 3 for shared memory. NVVM uses the default address space to
70 // represent global memory.
71 // Used in populateGpuToNVVMConversionPatternsso attaching here for now.
72 // TODO: We should have a single to_nvvm_type_converter.
73 populateGpuMemorySpaceAttributeConversions(
74 typeConverter&: llvmTypeConverter, mapping: [](AddressSpace space) -> unsigned {
75 switch (space) {
76 case AddressSpace::Global:
77 return static_cast<unsigned>(
78 NVVM::NVVMMemorySpace::kGlobalMemorySpace);
79 case AddressSpace::Workgroup:
80 return static_cast<unsigned>(
81 NVVM::NVVMMemorySpace::kSharedMemorySpace);
82 case AddressSpace::Private:
83 return 0;
84 }
85 llvm_unreachable("unknown address space enum value");
86 return 0;
87 });
88 // Used in GPUToNVVM/WmmaOpsToNvvm.cpp so attaching here for now.
89 // TODO: We should have a single to_nvvm_type_converter.
90 llvmTypeConverter.addConversion(
91 callback: [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
92 // Set higher benefit, so patterns will run before generic LLVM lowering.
93 populateGpuToNVVMConversionPatterns(converter: llvmTypeConverter, patterns,
94 benefit: getBenefit());
95}
96
97LogicalResult
98transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
99 transform::TypeConverterBuilderOpInterface builder) {
100 if (builder.getTypeConverterType() != "LLVMTypeConverter")
101 return emitOpError(message: "expected LLVMTypeConverter");
102 return success();
103}
104
105void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
106 TypeConverter &typeConverter, RewritePatternSet &patterns) {
107 auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
108 populateGpuWMMAToNVVMConversionPatterns(converter: llvmTypeConverter, patterns);
109}
110
111LogicalResult
112transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
113 transform::TypeConverterBuilderOpInterface builder) {
114 if (builder.getTypeConverterType() != "LLVMTypeConverter")
115 return emitOpError(message: "expected LLVMTypeConverter");
116 return success();
117}
118
119void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
120 populatePatterns(TypeConverter &typeConverter,
121 RewritePatternSet &patterns) {
122 auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
123 populateGpuSubgroupReduceOpLoweringPattern(converter: llvmTypeConverter, patterns);
124}
125
126LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
127 verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
128 if (builder.getTypeConverterType() != "LLVMTypeConverter")
129 return emitOpError(message: "expected LLVMTypeConverter");
130 return success();
131}
132
133void transform::ApplyGPUToROCDLConversionPatternsOp::populatePatterns(
134 TypeConverter &typeConverter, RewritePatternSet &patterns) {
135 auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
136 populateGpuMemorySpaceAttributeConversions(
137 typeConverter&: llvmTypeConverter, mapping: [](AddressSpace space) {
138 switch (space) {
139 case AddressSpace::Global:
140 return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
141 case AddressSpace::Workgroup:
142 return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
143 case AddressSpace::Private:
144 return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
145 }
146 llvm_unreachable("unknown address space enum value");
147 });
148 FailureOr<amdgpu::Chipset> maybeChipset =
149 amdgpu::Chipset::parse(name: getChipset());
150 assert(llvm::succeeded(maybeChipset) && "expected valid chipset");
151 populateGpuToROCDLConversionPatterns(
152 converter: llvmTypeConverter, patterns, runtime: mlir::gpu::amd::Runtime::HIP, chipset: *maybeChipset);
153}
154
155LogicalResult
156transform::ApplyGPUToROCDLConversionPatternsOp::verifyTypeConverter(
157 transform::TypeConverterBuilderOpInterface builder) {
158 FailureOr<amdgpu::Chipset> maybeChipset =
159 amdgpu::Chipset::parse(name: getChipset());
160 if (failed(Result: maybeChipset)) {
161 return emitOpError(message: "Invalid chipset name: " + getChipset());
162 }
163 if (builder.getTypeConverterType() != "LLVMTypeConverter")
164 return emitOpError(message: "expected LLVMTypeConverter");
165 return success();
166}
167
168//===----------------------------------------------------------------------===//
169// Apply...PatternsOp
//===----------------------------------------------------------------------===//
171
/// Populate the generic GPU dialect rewrite patterns (forwarding wrapper).
void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuRewritePatterns(patterns);
}
175
/// Populate patterns promoting gpu.shuffle to AMDGPU-specific ops
/// (forwarding wrapper).
void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
}
180
181//===----------------------------------------------------------------------===//
182// ApplyUnrollVectorsSubgroupMmaOp
183//===----------------------------------------------------------------------===//
184
185/// Pick an unrolling order that will allow tensorcore operation to reuse LHS
186/// register.
187static std::optional<SmallVector<int64_t>>
188gpuMmaUnrollOrder(vector::ContractionOp contract) {
189 SmallVector<int64_t> order;
190 // First make reduction the outer dimensions.
191 for (auto [index, iter] : llvm::enumerate(First: contract.getIteratorTypes())) {
192 if (vector::isReductionIterator(attr: iter)) {
193 order.push_back(Elt: index);
194 }
195 }
196
197 llvm::SmallDenseSet<int64_t> dims;
198 for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
199 dims.insert(V: cast<AffineDimExpr>(Val&: expr).getPosition());
200 }
201 // Then parallel dimensions that are part of Lhs as we want to re-use Lhs.
202 for (auto [index, iter] : llvm::enumerate(First: contract.getIteratorTypes())) {
203 if (vector::isParallelIterator(attr: iter) && dims.count(V: index)) {
204 order.push_back(Elt: index);
205 }
206 }
207 // Then the remaining parallel loops.
208 for (auto [index, iter] : llvm::enumerate(First: contract.getIteratorTypes())) {
209 if (vector::isParallelIterator(attr: iter) && !dims.count(V: index)) {
210 order.push_back(Elt: index);
211 }
212 }
213 return order;
214}
215
216/// Returns the target vector size for the target operation based on the native
217/// vector size specified with `m`, `n`, and `k`.
218static std::optional<SmallVector<int64_t>>
219getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
220 if (auto contract = dyn_cast<vector::ContractionOp>(Val: op)) {
221 int64_t contractRank = contract.getIteratorTypes().size();
222 if (contractRank < 3)
223 return std::nullopt;
224 SmallVector<int64_t> nativeSize(contractRank - 3, 1);
225 nativeSize.append(IL: {m, n, k});
226 return nativeSize;
227 }
228 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(Val: op)) {
229 int64_t writeRank = writeOp.getVectorType().getRank();
230 if (writeRank < 2)
231 return std::nullopt;
232 SmallVector<int64_t> nativeSize(writeRank - 2, 1);
233 nativeSize.append(IL: {m, n});
234 return nativeSize;
235 }
236 if (auto readOp = dyn_cast<vector::TransferReadOp>(Val: op)) {
237 // Transfer read ops may need different shapes based on how they are being
238 // used. For simplicity just match the shape used by the extract strided op.
239 VectorType sliceType;
240 for (Operation *users : op->getUsers()) {
241 auto extract = dyn_cast<vector::ExtractStridedSliceOp>(Val: users);
242 if (!extract)
243 return std::nullopt;
244 auto vecType = cast<VectorType>(Val: extract.getResult().getType());
245 if (sliceType && sliceType != vecType)
246 return std::nullopt;
247 sliceType = vecType;
248 }
249 return llvm::to_vector(Range: sliceType.getShape());
250 }
251 if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) {
252 if (auto vecType = dyn_cast<VectorType>(Val: op->getResultTypes()[0])) {
253 // TODO: The condition for unrolling elementwise should be restricted
254 // only to operations that need unrolling (connected to the contract).
255 if (vecType.getRank() < 2)
256 return std::nullopt;
257
258 // First check whether there is a slice to infer the shape from. This is
259 // required for cases where the accumulator type differs from the input
260 // types, in which case we will see an `arith.ext_` between the contract
261 // and transfer_read which needs to be unrolled.
262 VectorType sliceType;
263 for (Operation *users : op->getUsers()) {
264 auto extract = dyn_cast<vector::ExtractStridedSliceOp>(Val: users);
265 if (!extract)
266 return std::nullopt;
267 auto vecType = cast<VectorType>(Val: extract.getResult().getType());
268 if (sliceType && sliceType != vecType)
269 return std::nullopt;
270 sliceType = vecType;
271 }
272 if (sliceType)
273 return llvm::to_vector(Range: sliceType.getShape());
274
275 // Else unroll for trailing elementwise.
276 SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
277 // Map elementwise ops to the output shape.
278 nativeSize.append(IL: {m, n});
279 return nativeSize;
280 }
281 }
282 return std::nullopt;
283}
284
285void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
286 RewritePatternSet &patterns) {
287 auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
288 auto contract = dyn_cast<vector::ContractionOp>(Val: op);
289 if (!contract)
290 return std::nullopt;
291 return gpuMmaUnrollOrder(contract);
292 };
293
294 int64_t m = getM();
295 int64_t n = getN();
296 int64_t k = getK();
297 auto nativeShapeFn =
298 [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> {
299 return getSubgroupMmaNativeVectorSize(op, m, n, k);
300 };
301 vector::populateVectorUnrollPatterns(
302 patterns, options: vector::UnrollVectorOptions()
303 .setNativeShapeFn(nativeShapeFn)
304 .setUnrollTraversalOrderFn(unrollOrder));
305}
306
307//===----------------------------------------------------------------------===//
308// EliminateBarriersOp
309//===----------------------------------------------------------------------===//
310
/// Populate patterns removing redundant gpu.barrier ops (forwarding wrapper).
void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuEliminateBarriersPatterns(patterns);
}
314
315//===----------------------------------------------------------------------===//
316// Block and thread mapping utilities.
317//===----------------------------------------------------------------------===//
318
namespace {
/// Local types used for mapping verification.
/// Compile-time tag types selecting which flavor of GPU mapping a template
/// (e.g. checkMappingAttributeTypes) verifies.
struct MappingKind {};
/// Tag for mappings of an scf.forall onto GPU blocks.
struct BlockMappingKind : MappingKind {};
/// Tag for mappings of an scf.forall onto threads/warps/warpgroups/lanes.
struct ThreadMappingKind : MappingKind {};
} // namespace
325
326static DiagnosedSilenceableFailure
327definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
328 Operation *target, const Twine &message) {
329 if (transformOp.has_value())
330 return transformOp->emitDefiniteFailure() << message;
331 return emitDefiniteFailure(op: target, message);
332}
333
/// Check if given mapping attributes are one of the desired attributes.
/// Verifies that the scf.forall mapping spec is well formed for the mapping
/// kind selected by `MappingKindType` (block vs thread/warp/warpgroup/lane):
/// exactly one mapping family is used, ids are unique, linear and non-linear
/// modes are not mixed, and masking only appears in linear mode.
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  // A mapping attribute is mandatory; nothing can be mapped without one.
  if (!forallOp.getMapping().has_value()) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute");
  }

  // Detect which mapping families appear in the mapping array.
  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
                                      llvm::IsaPred<GPUBlockMappingAttr>);
  bool hasWarpgroupMapping = llvm::any_of(
      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPUWarpMappingAttr>);
  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
                                       llvm::IsaPred<GPUThreadMappingAttr>);
  bool hasLaneMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPULaneMappingAttr>);
  // At most one family may be used per forall; mixed granularity must be
  // expressed with nested foralls.
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  countMappingTypes += hasLaneMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }
  // The mapping family present must match the requested verification kind.
  if (std::is_same<MappingKindType, BlockMappingKind>::value &&
      !hasBlockMapping) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "scf.forall op requires a mapping attribute of kind 'block'");
  }
  if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
      !hasLaneMapping && !hasThreadMapping && !hasWarpMapping &&
      !hasWarpgroupMapping) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute "
                                 "of kind 'thread' or 'warp'");
  }

  // Each mapping id may be used by at most one loop.
  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicate attribute, cannot map different loops "
          "to the same mapping id");
    }
    seen.insert(map);
  }

  // Either all or none of the mapping attributes may be linear.
  auto isLinear = [](DeviceMappingAttrInterface attr) {
    return attr.isLinearMapping();
  };
  if (llvm::any_of(forallOp.getDeviceMappingAttrs(), isLinear) &&
      !llvm::all_of(forallOp.getDeviceMappingAttrs(), isLinear)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix linear and non-linear mapping modes");
  }

  // A masking attribute, when present, requires linear mapping mode.
  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
      forallOp.getDeviceMaskingAttr();
  if (succeeded(maybeMaskingAttr) && *maybeMaskingAttr &&
      !forallOp.usesLinearMapping()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "device masking is only available in linear mapping mode");
  }

  return DiagnosedSilenceableFailure::success();
}
411
/// Full GPU-mapping verification for an scf.forall: delegates the attribute
/// checks to checkMappingAttributeTypes, then enforces the structural
/// requirements (normalized, bufferized, statically sized, rank within the
/// limit of the chosen mapping scheme).
template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  // Check the types of the mapping attributes match.
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  // Perform other non-types verifications.
  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  // Results would have no producer once the forall is erased, so only
  // bufferized (zero-result) foralls can be mapped.
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  bool useLinearMapping = forallOp.usesLinearMapping();
  // TODO: This would be more natural with support for Optional<EnumParameter>
  // in GPUDeviceMappingAttr.
  // Linear mode supports every linear mapping id; 3-D mode is capped at the
  // three hardware dimensions.
  int64_t maxNumMappingsSupported =
      useLinearMapping ? (getMaxEnumValForMappingId() -
                          static_cast<uint64_t>(MappingId::DimZ))
                       : 3;
  if (forallOp.getRank() > maxNumMappingsSupported) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
  }
  // Upper bounds must be static constants so mapping sizes can be derived.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  return DiagnosedSilenceableFailure::success();
}
451
/// Struct to return the result of the rewrite of a forall operation.
struct ForallRewriteResult {
  // Number of iterations mapped to each id, sorted by mapping attribute id.
  SmallVector<int64_t> mappingSizes;
  // The hardware id values each forall induction variable was mapped to.
  SmallVector<Value> mappingIds;
};
457
458/// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR.
459template <typename OpTy, typename OperationOrBlock>
460static void
461replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
462 OperationOrBlock *parent, Value replacement,
463 ArrayRef<int64_t> availableMappingSizes) {
464 parent->walk([&](OpTy idOp) {
465 if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
466 rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
467 });
468}
469
/// Shared implementation rewriting one normalized, statically sized
/// scf.forall into straight-line code guarded by hardware ids produced by
/// `gpuIdBuilder`. When `availableMappingSizes` is non-empty it describes the
/// enclosing launch basis and predication is inserted; otherwise the forall
/// itself defines the basis and no predication occurs. On success, the sorted
/// mapping sizes and generated id values are returned in `result`.
static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
    ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
  LDBG("--start rewriteOneForallCommonImpl");

  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
         "requires statically sized, normalized forall op");
  SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
  SmallVector<DeviceMappingAttrInterface> forallMappingAttrsVec =
      forallOp.getDeviceMappingAttrs();
  SetVector<Attribute> forallMappingAttrs;
  forallMappingAttrs.insert_range(forallMappingAttrsVec);
  // Order mapping attributes by their mapping id.
  auto comparator = [](Attribute a, Attribute b) -> bool {
    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
           cast<DeviceMappingAttrInterface>(b).getMappingId();
  };

  // Step 1.b. In the linear case, compute the max mapping to avoid needlessly
  // mapping all dimensions. In the 3-D mapping case we need to map all
  // dimensions.
  DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>(
      *llvm::max_element(forallMappingAttrs, comparator));
  DeviceMappingAttrInterface maxLinearMapping;
  if (maxMapping.isLinearMapping())
    maxLinearMapping = maxMapping;
  for (auto attr : gpuIdBuilder.mappingAttributes) {
    // If attr overflows, just skip.
    if (maxLinearMapping && comparator(maxLinearMapping, attr))
      continue;
    // Try to insert. If element was already present, just continue.
    if (!forallMappingAttrs.insert(attr))
      continue;
    // Otherwise, we have a new insertion without a size -> use size 1.
    tmpMappingSizes.push_back(1);
  }
  LDBG("----tmpMappingSizes extracted from scf.forall op: "
       << llvm::interleaved(tmpMappingSizes));

  // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
  SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
      forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
  LDBG("----forallMappingSizes: " << llvm::interleaved(forallMappingSizes));
  LDBG("----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs));

  // Step 3. Generate the mappingIdOps using the provided generator.
  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forallOp);
  SmallVector<int64_t> originalBasis(availableMappingSizes);
  bool originalBasisWasProvided = !originalBasis.empty();
  if (!originalBasisWasProvided) {
    LDBG("----originalBasis was not provided, deriving it and there will be no "
         "predication");
    originalBasis = forallMappingSizes;
    // Pad the basis to a full 3-D shape with unit dimensions.
    while (originalBasis.size() < 3)
      originalBasis.push_back(1);
  } else {
    LDBG("----originalBasis was provided, using it, there will be predication");
  }
  LLVM_DEBUG(
      llvm::interleaveComma(originalBasis, DBGS() << "------originalBasis: ");
      llvm::dbgs() << "\n");

  IdBuilderResult builderResult =
      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);
  if (!builderResult.errorMsg.empty())
    return definiteFailureHelper(transformOp, forallOp, builderResult.errorMsg);

  LLVM_DEBUG(DBGS() << builderResult);

  // Step 4. Map the induction variables to the mappingIdOps, this may involve
  // a permutation.
  SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
  IRMapping bvm;
  for (auto [iv, dim] : llvm::zip_equal(
           forallOp.getInductionVars(),
           forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
    Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
    LDBG("----map: " << iv << " to " << peIdOp);
    bvm.map(iv, peIdOp);
  }

  // Step 5. If the originalBasis is already known, create conditionals to
  // predicate the region. Otherwise, the current forall determines the
  // originalBasis and no predication occurs.
  Value predicate;
  if (originalBasisWasProvided) {
    // AND all predicate values produced by the id builder together.
    for (Value tmpPredicate : builderResult.predicateOps) {
      predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                             tmpPredicate)
                            : tmpPredicate;
    }
  }

  // Step 6. Move the body of forallOp.
  // Erase the terminator first, it will not be used.
  rewriter.eraseOp(forallOp.getTerminator());
  Block *targetBlock;
  Block::iterator insertionPoint;
  if (predicate) {
    // Step 6.a. If predicated, move at the beginning.
    auto ifOp = rewriter.create<scf::IfOp>(loc, predicate,
                                           /*withElseRegion=*/false);
    targetBlock = ifOp.thenBlock();
    insertionPoint = ifOp.thenBlock()->begin();
  } else {
    // Step 6.b. Otherwise, move inline just at the rewriter insertion
    // point.
    targetBlock = forallOp->getBlock();
    insertionPoint = rewriter.getInsertionPoint();
  }
  Block &sourceBlock = forallOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // Step 7. RAUW indices.
  for (Value loopIndex : forallOp.getInductionVars()) {
    Value threadIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
  }

  // Step 8. Erase old op.
  rewriter.eraseOp(forallOp);

  LDBG("----result forallMappingSizes: "
       << llvm::interleaved(forallMappingSizes));
  LDBG("----result mappingIdOps: " << llvm::interleaved(mappingIdOps));

  result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
  return DiagnosedSilenceableFailure::success();
}
606
607//===----------------------------------------------------------------------===//
608// MapForallToBlocks
609//===----------------------------------------------------------------------===//
610
/// Map a top-level scf.forall onto GPU blocks: verifies the block mapping,
/// rewrites the forall via rewriteOneForallCommonImpl, and fills `gridDims`
/// (padded to 3-D with 1s) when it was passed in empty.
DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
    RewriterBase &rewriter, TransformOpInterface transformOp,
    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
    const GpuIdBuilder &gpuIdBuilder) {
  LDBG("Start mapForallToBlocksImpl");

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<BlockMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  Block *parentBlock = forallOp->getBlock();
  Value zero;
  {
    // Create an early zero index value for replacements and immediately reset
    // the insertion point.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(parentBlock);
    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  }

  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp,
      /*availableMappingSizes=*/gridDims, rewriteResult, gpuIdBuilder);

  // Return if anything goes wrong, use silenceable failure as a match
  // failure.
  if (!diag.succeeded())
    return diag;

  // If gridDims was not provided already, set it from the return.
  if (gridDims.empty()) {
    gridDims = rewriteResult.mappingSizes;
    while (gridDims.size() < 3)
      gridDims.push_back(1);
  }
  assert(gridDims.size() == 3 && "Need 3-D gridDims");

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
                                          rewriteResult.mappingSizes);

  return DiagnosedSilenceableFailure::success();
}
663
664DiagnosedSilenceableFailure
665mlir::transform::gpu::findTopLevelForallOp(Operation *target,
666 scf::ForallOp &topLevelForallOp,
667 TransformOpInterface transformOp) {
668 auto walkResult = target->walk(callback: [&](scf::ForallOp forallOp) {
669 if (forallOp->getParentOfType<scf::ForallOp>())
670 return WalkResult::advance();
671 if (topLevelForallOp)
672 // TODO: Handle multiple forall if they are independent.
673 return WalkResult::interrupt();
674 topLevelForallOp = forallOp;
675 return WalkResult::advance();
676 });
677
678 if (walkResult.wasInterrupted() || !topLevelForallOp)
679 return transformOp.emitSilenceableError()
680 << "could not find a unique topLevel scf.forall";
681 return DiagnosedSilenceableFailure::success();
682}
683
/// Transform-op entry point: locates the unique top-level scf.forall in
/// `target`, optionally creates a gpu.launch around it, maps it to blocks via
/// mapForallToBlocksImpl, and updates the launch grid dimensions.
DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  // Without `generate_gpu_launch` the payload must already be a gpu.launch.
  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag = mlir::transform::gpu::findTopLevelForallOp(
      target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  assert(topLevelForallOp && "expect an scf.forall");

  // When mapping into an existing launch, the grid dims must be fully given.
  SmallVector<int64_t> gridDims{getGridDims()};
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure("transform require size-3 mapping");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(topLevelForallOp);

  // Generate gpu launch here and move the forall inside
  if (getGenerateGpuLaunch()) {
    DiagnosedSilenceableFailure diag =
        createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded())
      return diag;

    // Clone the forall into the fresh launch body, then drop the original.
    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForallOp = rewriter.clone(*topLevelForallOp);
    rewriter.eraseOp(topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(newForallOp);
  }

  // The BlockIdBuilder adapts to whatever is thrown at it.
  bool useLinearMapping = false;
  if (topLevelForallOp.getMapping())
    useLinearMapping = topLevelForallOp.usesLinearMapping();

  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
      topLevelForallOp.getDeviceMaskingAttr();
  assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
  assert((!*maybeMaskingAttr || useLinearMapping) &&
         "masking requires linear mapping");

  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping,
                                      *maybeMaskingAttr);

  diag = mlir::transform::gpu::mapForallToBlocksImpl(
      rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Set the GPU launch configuration for the grid dims late, this is
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        cast<TransformOpInterface>(getOperation()), gridDims[0],
                        gridDims[1], gridDims[2]);

  results.push_back(gpuLaunch);
  return diag;
}
756
757LogicalResult transform::MapForallToBlocks::verify() {
758 if (!getGridDims().empty() && getGridDims().size() != 3) {
759 return emitOpError() << "transform requires empty or size-3 grid_dims";
760 }
761 return success();
762}
763
764//===----------------------------------------------------------------------===//
765// MapNestedForallToThreads
766//===----------------------------------------------------------------------===//
767
768static DiagnosedSilenceableFailure checkMappingSpec(
769 std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
770 ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
771 int factor, bool useLinearMapping = false) {
772 if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
773 auto diag = definiteFailureHelper(
774 transformOp, target: forallOp,
775 message: Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
776 Twine(factor));
777 return diag;
778 }
779 if (computeProduct(basis: numParallelIterations) * factor >
780 computeProduct(basis: blockOrGridSizes)) {
781 auto diag = definiteFailureHelper(
782 transformOp, target: forallOp,
783 message: Twine("the number of required parallel resources (blocks or "
784 "threads) ") +
785 Twine(computeProduct(basis: numParallelIterations) * factor) +
786 " overflows the number of available resources " +
787 Twine(computeProduct(basis: blockOrGridSizes)));
788 return diag;
789 }
790 return DiagnosedSilenceableFailure::success();
791}
792
/// Select the GpuIdBuilder matching the device mapping attribute carried by
/// `forallOp` (warpgroup, warp, thread, or lane) and store it in
/// `gpuIdBuilder`. Returns a definite failure when the forall op is not
/// normalized with static upper bounds, or when the mapping does not fit
/// within `blockSizes` (see checkMappingSpec).
static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                   scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                   int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
  // The first mapping attribute determines which id builder to use.
  DeviceMappingAttrInterface mappingAttr =
      forallOp.getDeviceMappingAttrs().front();
  bool useLinearMapping = mappingAttr.isLinearMapping();

  // Sanity checks that may result in runtime verification errors.
  auto numParallelIterations =
      getConstantIntValues(ofrs: (forallOp.getMixedUpperBound()));
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, target: forallOp,
        message: "requires statically sized, normalized forall op");
  }
  // `factor` is the number of consecutive thread ids each mapped id consumes:
  // a warp consumes `warpSize` threads, a warpgroup `kNumWarpsPerGroup` warps.
  int64_t factor = 1;
  if (isa<GPUWarpgroupMappingAttr>(Val: mappingAttr)) {
    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
  } else if (isa<GPUWarpMappingAttr>(Val: mappingAttr)) {
    factor = warpSize;
  }
  DiagnosedSilenceableFailure diag =
      checkMappingSpec(transformOp, forallOp, numParallelIterations: numParallelIterations.value(),
                       blockOrGridSizes: blockSizes, factor, useLinearMapping);
  if (!diag.succeeded())
    return diag;

  // Masking, when present, is only meaningful with linear mapping; both
  // invariants are expected to have been established by earlier verification.
  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
      forallOp.getDeviceMaskingAttr();
  assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
  assert((!*maybeMaskingAttr || useLinearMapping) &&
         "masking requires linear mapping");

  // Start mapping.
  MLIRContext *ctx = forallOp.getContext();
  gpuIdBuilder =
      TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
          .Case(caseFn: [&](GPUWarpgroupMappingAttr) {
            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping,
                                         *maybeMaskingAttr);
          })
          .Case(caseFn: [&](GPUWarpMappingAttr) {
            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping,
                                    *maybeMaskingAttr);
          })
          .Case(caseFn: [&](GPUThreadMappingAttr) {
            return GpuThreadIdBuilder(ctx, useLinearMapping, *maybeMaskingAttr);
          })
          .Case(caseFn: [&](GPULaneMappingAttr) {
            return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping,
                                    *maybeMaskingAttr);
          })
          // All supported mapping kinds are enumerated above; anything else is
          // a programming error caught here.
          .Default(defaultFn: [&](DeviceMappingAttrInterface) -> GpuIdBuilder {
            llvm_unreachable("unknown mapping attribute");
          });
  return DiagnosedSilenceableFailure::success();
}
851
852DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
853 RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
854 scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
855 bool syncAfterDistribute) {
856
857 {
858 // GPU-specific verifications. There is no better place to anchor
859 // those right now: the ForallOp is target-independent and the transform
860 // op does not apply to individual ForallOp.
861 DiagnosedSilenceableFailure diag =
862 verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp);
863 if (!diag.succeeded())
864 return diag;
865 }
866
867 GpuIdBuilder gpuIdBuilder;
868 {
869 // Try to construct the id builder, if it fails, return.
870 DiagnosedSilenceableFailure diag = getThreadIdBuilder(
871 transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder);
872 if (!diag.succeeded())
873 return diag;
874 }
875
876 Location loc = forallOp.getLoc();
877 OpBuilder::InsertionGuard g(rewriter);
878 // Insert after to allow for syncthreads after `forall` is erased.
879 rewriter.setInsertionPointAfter(forallOp);
880 ForallRewriteResult rewriteResult;
881 DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
882 rewriter, transformOp, forallOp, availableMappingSizes: blockSizes, result&: rewriteResult, gpuIdBuilder);
883 if (!diag.succeeded())
884 return diag;
885 // Add a syncthreads if needed. TODO: warpsync
886 if (syncAfterDistribute)
887 rewriter.create<BarrierOp>(location: loc);
888
889 return DiagnosedSilenceableFailure::success();
890}
891
892DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
893 RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
894 Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
895 bool syncAfterDistribute) {
896 LDBG("Start mapNestedForallToThreadsImpl");
897 if (blockDims.size() != 3) {
898 return definiteFailureHelper(transformOp, target,
899 message: "requires size-3 thread mapping");
900 }
901
902 // Create an early zero index value for replacements.
903 Location loc = target->getLoc();
904 Value zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
905 DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
906 WalkResult walkResult = target->walk(callback: [&](scf::ForallOp forallOp) {
907 diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
908 rewriter, transformOp, forallOp, blockSizes: blockDims, warpSize,
909 syncAfterDistribute);
910 if (diag.isDefiniteFailure())
911 return WalkResult::interrupt();
912 if (diag.succeeded())
913 return WalkResult::skip();
914 return WalkResult::advance();
915 });
916 if (walkResult.wasInterrupted())
917 return diag;
918
919 // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
920 // Here, the result of mapping determines the available mapping sizes.
921 replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, parent: target, replacement: zero,
922 availableMappingSizes: blockDims);
923
924 return DiagnosedSilenceableFailure::success();
925}
926
/// Apply the transform to one payload op: verify it is a gpu.launch, set its
/// block dimensions, then map every nested scf.forall to thread-level ids.
DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(Val: target);
  auto transformOp = cast<TransformOpInterface>(Val: getOperation());

  // Basic high-level verifications.
  if (!gpuLaunch)
    return emitSilenceableError() << "Given target is not a gpu.launch";

  // Mapping to block ids.
  SmallVector<int64_t> blockDims{getBlockDims()};
  // Reject block dimensions that exceed the device limits before rewriting
  // anything; only the block dims are checked here (grid dims are untouched).
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, gridDimX: std::nullopt, gridDimY: std::nullopt, gridDimZ: std::nullopt,
                     blockDimX: blockDims[0], blockDimY: blockDims[1], blockDimZ: blockDims[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(loc: getLoc()) << getBlockDimsAttrName() << " is too large";
    return diag;
  }

  // Set the GPU launch configuration for the block dims early, this is not
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, gridDimX: std::nullopt,
                        gridDimY: std::nullopt, gridDimZ: std::nullopt, blockDimX: blockDims[0], blockDimY: blockDims[1],
                        blockDimZ: blockDims[2]);

  // Rewrite the nested foralls inside the launch body.
  rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
  diag =
      mapNestedForallToThreadsImpl(rewriter, transformOp, target: gpuLaunch, blockDims,
                                   warpSize: getWarpSize(), syncAfterDistribute: getSyncAfterDistribute());

  // The (possibly rewritten) launch op is the transform's result.
  results.push_back(op: gpuLaunch.getOperation());
  return diag;
}
961
962//===----------------------------------------------------------------------===//
963// Transform op registration
964//===----------------------------------------------------------------------===//
965
namespace {
/// Transform dialect extension that registers the GPU transform ops and
/// declares the dialects those ops may generate ops from (gpu, amdgpu,
/// arith, scf) so they get loaded when the extension is applied.
/// NOTE(review): a previous version of this comment mentioned PDL; the code
/// below no longer declares PDL as a dependent dialect.
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)

  GPUTransformDialectExtension() {
    // Dialects whose ops the registered transforms may create.
    declareGeneratedDialect<GPUDialect>();
    declareGeneratedDialect<amdgpu::AMDGPUDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<scf::SCFDialect>();
    // Register all ops generated from GPUTransformOps.td.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace
987
988#define GET_OP_CLASSES
989#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
990
/// Public entry point: add the GPU transform dialect extension to `registry`.
void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}
994

// source code of mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp