//===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
using namespace mlir::transform::gpu;

#define DEBUG_TYPE "gpu-transforms"
#define DEBUG_TYPE_ALIAS "gpu-transforms-alias"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  // Used in populateGpuToNVVMConversionPatterns, so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](AddressSpace space) -> unsigned {
        switch (space) {
        case AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Used in GPUToNVVM/WmmaOpsToNvvm.cpp so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  llvmTypeConverter.addConversion(
      [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
  // Set higher benefit, so patterns will run before generic LLVM lowering.
  populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns,
                                      getBenefit());
}

LogicalResult
transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    populatePatterns(TypeConverter &typeConverter,
                     RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns);
}

LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuRewritePatterns(patterns);
}

void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// ApplyUnrollVectorsSubgroupMmaOp
//===----------------------------------------------------------------------===//

/// Pick an unrolling order that allows the tensorcore operations to reuse the
/// LHS register.
static std::optional<SmallVector<int64_t>>
gpuMmaUnrollOrder(vector::ContractionOp contract) {
  SmallVector<int64_t> order;
  // First make reduction the outer dimensions.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isReductionIterator(iter)) {
      order.push_back(index);
    }
  }

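  // Collect the iteration-space dimensions accessed by the LHS operand
  // (indexing map 0).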
  llvm::SmallDenseSet<int64_t> dims;
  for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
    dims.insert(cast<AffineDimExpr>(expr).getPosition());
  }
  // Then parallel dimensions that are part of Lhs as we want to re-use Lhs.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && dims.count(index)) {
      order.push_back(index);
    }
  }
  // Then the remaining parallel loops.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && !dims.count(index)) {
      order.push_back(index);
    }
  }
  return order;
}

/// Returns the target vector size for the target operation based on the native
/// vector size specified with `m`, `n`, and `k`.
static std::optional<SmallVector<int64_t>>
getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
    int64_t contractRank = contract.getIteratorTypes().size();
    if (contractRank < 3)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(contractRank - 3, 1);
    nativeSize.append({m, n, k});
    return nativeSize;
  }
  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    int64_t writeRank = writeOp.getVectorType().getRank();
    if (writeRank < 2)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(writeRank - 2, 1);
    nativeSize.append({m, n});
    return nativeSize;
  }
  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
    // Transfer read ops may need different shapes based on how they are being
    // used. For simplicity just match the shape used by the extract strided op.
    VectorType sliceType;
    for (Operation *users : op->getUsers()) {
      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
      if (!extract)
        return std::nullopt;
      auto vecType = cast<VectorType>(extract.getResult().getType());
      if (sliceType && sliceType != vecType)
        return std::nullopt;
      sliceType = vecType;
    }
    return llvm::to_vector(sliceType.getShape());
  }
  if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) {
    if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
      // TODO: The condition for unrolling elementwise should be restricted
      // only to operations that need unrolling (connected to the contract).
      if (vecType.getRank() < 2)
        return std::nullopt;

      // First check whether there is a slice to infer the shape from. This is
      // required for cases where the accumulator type differs from the input
      // types, in which case we will see an `arith.ext_` between the contract
      // and transfer_read which needs to be unrolled.
      VectorType sliceType;
      for (Operation *users : op->getUsers()) {
        auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
        if (!extract)
          return std::nullopt;
        auto vecType = cast<VectorType>(extract.getResult().getType());
        if (sliceType && sliceType != vecType)
          return std::nullopt;
        sliceType = vecType;
      }
      if (sliceType)
        return llvm::to_vector(sliceType.getShape());

      // Else unroll for trailing elementwise.
      SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
      // Map elementwise ops to the output shape.
      nativeSize.append({m, n});
      return nativeSize;
    }
  }
  return std::nullopt;
}

void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
    RewritePatternSet &patterns) {
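  // Unroll vector.contract ops following the LHS-reuse order computed by
  // gpuMmaUnrollOrder; all matched ops are unrolled to the native subgroup MMA
  // shape derived from the op's m/n/k attributes.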
  auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
    auto contract = dyn_cast<vector::ContractionOp>(op);
    if (!contract)
      return std::nullopt;
    return gpuMmaUnrollOrder(contract);
  };

  int64_t m = getM();
  int64_t n = getN();
  int64_t k = getK();
  auto nativeShapeFn =
      [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> {
    return getSubgroupMmaNativeVectorSize(op, m, n, k);
  };
  vector::populateVectorUnrollPatterns(
      patterns, vector::UnrollVectorOptions()
                    .setNativeShapeFn(nativeShapeFn)
                    .setUnrollTraversalOrderFn(unrollOrder));
}

//===----------------------------------------------------------------------===//
// EliminateBarriersOp
//===----------------------------------------------------------------------===//

void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuEliminateBarriersPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Block and thread mapping utilities.
//===----------------------------------------------------------------------===//

namespace {
/// Local types used for mapping verification.
struct MappingKind {};
struct BlockMappingKind : MappingKind {};
struct ThreadMappingKind : MappingKind {};
} // namespace

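/// Emit a definite failure diagnostic, anchored on `transformOp` when it is
/// provided and on the payload `target` op otherwise.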
static DiagnosedSilenceableFailure
definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
                      Operation *target, const Twine &message) {
  if (transformOp.has_value())
    return transformOp->emitDefiniteFailure() << message;
  return emitDefiniteFailure(target, message);
}

/// Check that the given mapping attributes are among the desired attributes.
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  if (!forallOp.getMapping().has_value()) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute");
  }

  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
                                      llvm::IsaPred<GPUBlockMappingAttr>);
  bool hasWarpgroupMapping = llvm::any_of(
      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPUWarpMappingAttr>);
  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
                                       llvm::IsaPred<GPUThreadMappingAttr>);
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }
  if (std::is_same<MappingKindType, BlockMappingKind>::value &&
      !hasBlockMapping) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "scf.forall op requires a mapping attribute of kind 'block'");
  }
  if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
      !hasThreadMapping && !hasWarpMapping && !hasWarpgroupMapping) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute "
                                 "of kind 'thread' or 'warp'");
  }

  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicate attribute, cannot map different loops "
          "to the same mapping id");
    }
    seen.insert(map);
  }

  auto isLinear = [](Attribute a) {
    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
  };
  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix linear and non-linear mapping modes");
  }

  return DiagnosedSilenceableFailure::success();
}

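/// Verify that `forallOp` can be mapped for `MappingKindType`: check the
/// mapping attributes, then the structural requirements (normalized,
/// statically sized, bufferized loop with a rank supported by the mapping).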
template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  // Check that the types of the mapping attributes match.
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  // Perform other non-type verifications.
  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
                              forallOp.getMapping()->getValue().front())
                              .isLinearMapping();
  // TODO: This would be more natural with support for Optional<EnumParameter>
  // in GPUDeviceMappingAttr.
  int64_t maxNumMappingsSupported =
      useLinearMapping ? (getMaxEnumValForMappingId() -
                          static_cast<uint64_t>(MappingId::DimZ))
                       : 3;
  if (forallOp.getRank() > maxNumMappingsSupported) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
  }
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  return DiagnosedSilenceableFailure::success();
}

/// Struct to return the result of the rewrite of a forall operation.
struct ForallRewriteResult {
  SmallVector<int64_t> mappingSizes;
  SmallVector<Value> mappingIds;
};

/// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR.
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
                            OperationOrBlock *parent, Value replacement,
                            ArrayRef<int64_t> availableMappingSizes) {
  parent->walk([&](OpTy idOp) {
    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
  });
}

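/// Rewrite a single, normalized scf.forall op into GPU ids produced by
/// `gpuIdBuilder`. When `availableMappingSizes` is provided and a dimension
/// has fewer loop iterations than available ids, the body is guarded by an
/// scf.if predicate. On success, the mapping sizes and id values are returned
/// in `result` and the forall op is erased.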
static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
    ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
  LDBG("--start rewriteOneForallCommonImpl");

  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
         "requires statically sized, normalized forall op");
  SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
  SetVector<Attribute> forallMappingAttrs;
  forallMappingAttrs.insert_range(forallOp.getMapping()->getValue());
  auto comparator = [](Attribute a, Attribute b) -> bool {
    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
           cast<DeviceMappingAttrInterface>(b).getMappingId();
  };

  // Step 1.b. In the linear case, compute the max mapping to avoid needlessly
  // mapping all dimensions. In the 3-D mapping case we need to map all
  // dimensions.
  DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>(
      *llvm::max_element(forallMappingAttrs, comparator));
  DeviceMappingAttrInterface maxLinearMapping;
  if (maxMapping.isLinearMapping())
    maxLinearMapping = maxMapping;
  for (auto attr : gpuIdBuilder.mappingAttributes) {
    // If attr overflows, just skip.
    if (maxLinearMapping && comparator(maxLinearMapping, attr))
      continue;
    // Try to insert. If element was already present, just continue.
    if (!forallMappingAttrs.insert(attr))
      continue;
    // Otherwise, we have a new insertion without a size -> use size 1.
    tmpMappingSizes.push_back(1);
  }
  LDBG("----tmpMappingSizes extracted from scf.forall op: "
       << llvm::interleaved(tmpMappingSizes));

  // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
  SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
      forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
  LDBG("----forallMappingSizes: " << llvm::interleaved(forallMappingSizes));
  LDBG("----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs));

  // Step 3. Generate the mappingIdOps using the provided generator.
  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forallOp);
  SmallVector<int64_t> originalBasis(availableMappingSizes);
  bool originalBasisWasProvided = !originalBasis.empty();
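  // When the caller does not provide a basis, the forall op itself defines it;
  // pad with unit sizes up to 3-D.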
  if (!originalBasisWasProvided) {
    originalBasis = forallMappingSizes;
    while (originalBasis.size() < 3)
      originalBasis.push_back(1);
  }

  IdBuilderResult builderResult =
      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);

  // Step 4. Map the induction variables to the mappingIdOps, this may involve
  // a permutation.
  SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
  IRMapping bvm;
  for (auto [iv, dim] : llvm::zip_equal(
           forallOp.getInductionVars(),
           forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
    Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
    bvm.map(iv, peIdOp);
  }

  // Step 5. If the originalBasis is already known, create conditionals to
  // predicate the region. Otherwise, the current forall determines the
  // originalBasis and no predication occurs.
  Value predicate;
  if (originalBasisWasProvided) {
    SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes;
    SmallVector<int64_t> availableMappingSizes =
        builderResult.availableMappingSizes;
    SmallVector<Value> activeIdOps = builderResult.activeIdOps;
    LDBG("----activeMappingSizes: " << llvm::interleaved(activeMappingSizes));
    LDBG("----availableMappingSizes: "
         << llvm::interleaved(availableMappingSizes));
    LDBG("----activeIdOps: " << llvm::interleaved(activeIdOps));
    for (auto [activeId, activeMappingSize, availableMappingSize] :
         llvm::zip_equal(activeIdOps, activeMappingSizes,
                         availableMappingSizes)) {
      if (activeMappingSize > availableMappingSize) {
        return definiteFailureHelper(
            transformOp, forallOp,
            "Trying to map to fewer GPU threads than loop iterations but "
            "overprovisioning is not yet supported. "
            "Try additional tiling before mapping or map to more threads.");
      }
      if (activeMappingSize == availableMappingSize)
        continue;
      Value idx =
          rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
      Value tmpPredicate = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, activeId, idx);
      LDBG("----predicate: " << tmpPredicate);
      predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                             tmpPredicate)
                            : tmpPredicate;
    }
  }

  // Step 6. Move the body of forallOp.
  // Erase the terminator first, it will not be used.
  rewriter.eraseOp(forallOp.getTerminator());
  Block *targetBlock;
  Block::iterator insertionPoint;
  if (predicate) {
    // Step 6.a. If predicated, move at the beginning.
    auto ifOp = rewriter.create<scf::IfOp>(loc, predicate,
                                           /*withElseRegion=*/false);
    targetBlock = ifOp.thenBlock();
    insertionPoint = ifOp.thenBlock()->begin();
  } else {
    // Step 6.b. Otherwise, move inline just at the rewriter insertion
    // point.
    targetBlock = forallOp->getBlock();
    insertionPoint = rewriter.getInsertionPoint();
  }
  Block &sourceBlock = forallOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // Step 7. RAUW indices.
  for (Value loopIndex : forallOp.getInductionVars()) {
    Value threadIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
  }

  // Step 8. Erase old op.
  rewriter.eraseOp(forallOp);

  LDBG("----result forallMappingSizes: "
       << llvm::interleaved(forallMappingSizes));
  LDBG("----result mappingIdOps: " << llvm::interleaved(mappingIdOps));

  result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MapForallToBlocks
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
    RewriterBase &rewriter, TransformOpInterface transformOp,
    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
    const GpuIdBuilder &gpuIdBuilder) {
  LDBG("Start mapForallToBlocksImpl");

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<BlockMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  Block *parentBlock = forallOp->getBlock();
  Value zero;
  {
    // Create an early zero index value for replacements and immediately reset
    // the insertion point.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(parentBlock);
    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  }

  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp,
      /*availableMappingSizes=*/gridDims, rewriteResult, gpuIdBuilder);

  // Return if anything goes wrong, use silenceable failure as a match
  // failure.
  if (!diag.succeeded())
    return diag;

  // If gridDims was not provided already, set it from the return.
  if (gridDims.empty()) {
    gridDims = rewriteResult.mappingSizes;
    while (gridDims.size() < 3)
      gridDims.push_back(1);
  }
  assert(gridDims.size() == 3 && "Need 3-D gridDims");

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
                                          rewriteResult.mappingSizes);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
mlir::transform::gpu::findTopLevelForallOp(Operation *target,
                                           scf::ForallOp &topLevelForallOp,
                                           TransformOpInterface transformOp) {
  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
    if (forallOp->getParentOfType<scf::ForallOp>())
      return WalkResult::advance();
    if (topLevelForallOp)
      // TODO: Handle multiple forall if they are independent.
      return WalkResult::interrupt();
    topLevelForallOp = forallOp;
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted() || !topLevelForallOp)
    return transformOp.emitSilenceableError()
           << "could not find a unique topLevel scf.forall";
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag = mlir::transform::gpu::findTopLevelForallOp(
      target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  assert(topLevelForallOp && "expect an scf.forall");

  SmallVector<int64_t> gridDims{getGridDims()};
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure(
        "transform requires size-3 mapping");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(topLevelForallOp);

  // Generate gpu launch here and move the forall inside.
  if (getGenerateGpuLaunch()) {
    DiagnosedSilenceableFailure diag =
        createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded())
      return diag;

    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForallOp = rewriter.clone(*topLevelForallOp);
    rewriter.eraseOp(topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(newForallOp);
  }

  // The BlockIdBuilder adapts to whatever is thrown at it.
  bool useLinearMapping = false;
  if (topLevelForallOp.getMapping()) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(
        topLevelForallOp.getMapping()->getValue().front());
    useLinearMapping = mappingAttr.isLinearMapping();
  }
  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);

  diag = mlir::transform::gpu::mapForallToBlocksImpl(
      rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Set the GPU launch configuration for the grid dims late; this is
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        cast<TransformOpInterface>(getOperation()), gridDims[0],
                        gridDims[1], gridDims[2]);

  results.push_back(gpuLaunch);
  return diag;
}

LogicalResult transform::MapForallToBlocks::verify() {
  if (!getGridDims().empty() && getGridDims().size() != 3) {
    return emitOpError() << "transform requires empty or size-3 grid_dims";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MapNestedForallToThreads
//===----------------------------------------------------------------------===//

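/// Check that the requested block or grid sizes are compatible with the
/// mapping: in the 3-D case the x dimension must be a multiple of `factor`
/// (the number of threads consumed per warp or warpgroup id), and the total
/// number of required ids must not exceed the number of available ones.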
static DiagnosedSilenceableFailure checkMappingSpec(
    std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
    ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
    int factor, bool useLinearMapping = false) {
  if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
            Twine(factor));
    return diag;
  }
  if (computeProduct(numParallelIterations) * factor >
      computeProduct(blockOrGridSizes)) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("the number of required parallel resources (blocks or "
              "threads) ") +
            Twine(computeProduct(numParallelIterations) * factor) +
            " overflows the number of available resources " +
            Twine(computeProduct(blockOrGridSizes)));
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}

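/// Select the GPU id builder (thread, warp, or warpgroup) matching the mapping
/// attribute of `forallOp`, after verifying that the loop and the requested
/// `blockSizes` satisfy the mapping constraints.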
static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                   scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                   int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
  auto mappingAttr = cast<DeviceMappingAttrInterface>(
      forallOp.getMapping()->getValue().front());
  bool useLinearMapping = mappingAttr.isLinearMapping();

  // Sanity checks that may result in runtime verification errors.
  auto numParallelIterations =
      getConstantIntValues((forallOp.getMixedUpperBound()));
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
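  // Each warp id consumes `warpSize` threads and each warpgroup id consumes
  // `kNumWarpsPerGroup` warps of the block.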
  int64_t factor = 1;
  if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
  } else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
    factor = warpSize;
  }
  DiagnosedSilenceableFailure diag =
      checkMappingSpec(transformOp, forallOp, numParallelIterations.value(),
                       blockSizes, factor, useLinearMapping);
  if (!diag.succeeded())
    return diag;

  // Start mapping.
  MLIRContext *ctx = forallOp.getContext();
  gpuIdBuilder =
      TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
          .Case([&](GPUWarpgroupMappingAttr) {
            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUWarpMappingAttr) {
            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUThreadMappingAttr) {
            return GpuThreadIdBuilder(ctx, useLinearMapping);
          })
          .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
            llvm_unreachable("unknown mapping attribute");
          });
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
    bool syncAfterDistribute) {

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  GpuIdBuilder gpuIdBuilder;
  {
    // Try to construct the id builder, if it fails, return.
    DiagnosedSilenceableFailure diag = getThreadIdBuilder(
        transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  // Insert after to allow for syncthreads after `forall` is erased.
  rewriter.setInsertionPointAfter(forallOp);
  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder);
  if (!diag.succeeded())
    return diag;
  // Add a syncthreads if needed. TODO: warpsync
  if (syncAfterDistribute)
    rewriter.create<BarrierOp>(loc);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
    bool syncAfterDistribute) {
  LDBG("Start mapNestedForallToThreadsImpl");
  if (blockDims.size() != 3) {
    return definiteFailureHelper(transformOp, target,
                                 "requires size-3 thread mapping");
  }

  // Create an early zero index value for replacements.
  Location loc = target->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
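  // Map each nested scf.forall op. A successfully mapped op has been erased,
  // so skip walking into it; a definite failure aborts the walk, a silenceable
  // one falls through to the next op.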
  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
        rewriter, transformOp, forallOp, blockDims, warpSize,
        syncAfterDistribute);
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();
    if (diag.succeeded())
      return WalkResult::skip();
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return diag;

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
                                          blockDims);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  // Basic high-level verifications.
  if (!gpuLaunch)
    return emitSilenceableError() << "Given target is not a gpu.launch";

  // Mapping to block ids.
  SmallVector<int64_t> blockDims{getBlockDims()};
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                     blockDims[0], blockDims[1], blockDims[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
    return diag;
  }

  // Set the GPU launch configuration for the block dims early; this is not
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
                        std::nullopt, std::nullopt, blockDims[0], blockDims[1],
                        blockDims[2]);

  rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
  diag =
      mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
                                   getWarpSize(), getSyncAfterDistribute());

  results.push_back(gpuLaunch.getOperation());
  return diag;
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
/// Registers the GPU transform ops and declares the dialects whose operations
/// they may generate.
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)

  GPUTransformDialectExtension() {
    declareGeneratedDialect<GPUDialect>();
    declareGeneratedDialect<amdgpu::AMDGPUDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<scf::SCFDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"

void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}
