//===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
using namespace mlir::transform::gpu;

#define DEBUG_TYPE "gpu-transforms"
#define DEBUG_TYPE_ALIAS "gpu-transforms-alias"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")

//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

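// The ops below expose GPU-to-NVVM pattern populators to transform IR
// (typically through the transform dialect's apply_conversion_patterns
// mechanism). They all expect the enclosing type converter to be an
// LLVMTypeConverter, which is what the verifyTypeConverter hooks check.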
void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  // Used in populateGpuToNVVMConversionPatterns, so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](AddressSpace space) -> unsigned {
        switch (space) {
        case AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Used in GPUToNVVM/WmmaOpsToNvvm.cpp, so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  llvmTypeConverter.addConversion(
      [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
  populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    populatePatterns(TypeConverter &typeConverter,
                     RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns);
}

LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuRewritePatterns(patterns);
}

//===----------------------------------------------------------------------===//
// ApplyUnrollVectorsSubgroupMmaOp
//===----------------------------------------------------------------------===//

/// Pick an unrolling order that allows the tensor core operations to reuse
/// the LHS register.
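///
/// For example (illustrative only), for a plain matmul contraction with
/// iterator_types = [parallel, parallel, reduction] (m, n, k) and an LHS
/// indexing map over (m, k), the returned order is [k, m, n]: the reduction
/// dimension first, then the parallel dimensions appearing in the LHS, then
/// the remaining parallel dimensions.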
static std::optional<SmallVector<int64_t>>
gpuMmaUnrollOrder(vector::ContractionOp contract) {
  SmallVector<int64_t> order;
  // First make reduction the outer dimensions.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isReductionIterator(iter)) {
      order.push_back(index);
    }
  }

  llvm::SmallDenseSet<int64_t> dims;
  for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
    dims.insert(cast<AffineDimExpr>(expr).getPosition());
  }
  // Then parallel dimensions that are part of Lhs as we want to re-use Lhs.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && dims.count(index)) {
      order.push_back(index);
    }
  }
  // Then the remaining parallel loops.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && !dims.count(index)) {
      order.push_back(index);
    }
  }
  return order;
}

/// Returns the target vector size for the target operation based on the native
/// vector size specified with `m`, `n`, and `k`.
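///
/// For example (illustrative only), with native WMMA sizes m = 16, n = 16,
/// k = 8: a rank-3 vector.contract unrolls to {16, 16, 8}, a rank-2
/// vector.transfer_write unrolls to {16, 16}, and a vector.transfer_read
/// takes its shape from the vector.extract_strided_slice ops that consume it.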
static std::optional<SmallVector<int64_t>>
getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
    int64_t contractRank = contract.getIteratorTypes().size();
    if (contractRank < 3)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(contractRank - 3, 1);
    nativeSize.append({m, n, k});
    return nativeSize;
  }
  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    int64_t writeRank = writeOp.getVectorType().getRank();
    if (writeRank < 2)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(writeRank - 2, 1);
    nativeSize.append({m, n});
    return nativeSize;
  }
  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
    // Transfer read ops may need different shapes based on how they are being
    // used. For simplicity just match the shape used by the extract strided op.
    VectorType sliceType;
    for (Operation *users : op->getUsers()) {
      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
      if (!extract)
        return std::nullopt;
      auto vecType = cast<VectorType>(extract.getResult().getType());
      if (sliceType && sliceType != vecType)
        return std::nullopt;
      sliceType = vecType;
    }
    return llvm::to_vector(sliceType.getShape());
  }
  if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) {
    if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
      // TODO: The condition for unrolling elementwise should be restricted
      // only to operations that need unrolling (connected to the contract).
      if (vecType.getRank() < 2)
        return std::nullopt;

      // First check whether there is a slice to infer the shape from. This is
      // required for cases where the accumulator type differs from the input
      // types, in which case we will see an `arith.ext_` between the contract
      // and transfer_read which needs to be unrolled.
      VectorType sliceType;
      for (Operation *users : op->getUsers()) {
        auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
        if (!extract)
          return std::nullopt;
        auto vecType = cast<VectorType>(extract.getResult().getType());
        if (sliceType && sliceType != vecType)
          return std::nullopt;
        sliceType = vecType;
      }
      if (sliceType)
        return llvm::to_vector(sliceType.getShape());

      // Else unroll for trailing elementwise.
      SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
      // Map elementwise ops to the output shape.
      nativeSize.append({m, n});
      return nativeSize;
    }
  }
  return std::nullopt;
}

void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
    RewritePatternSet &patterns) {
  auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
    auto contract = dyn_cast<vector::ContractionOp>(op);
    if (!contract)
      return std::nullopt;
    return gpuMmaUnrollOrder(contract);
  };

  int64_t m = getM();
  int64_t n = getN();
  int64_t k = getK();
  auto nativeShapeFn =
      [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> {
    return getSubgroupMmaNativeVectorSize(op, m, n, k);
  };
  vector::populateVectorUnrollPatterns(
      patterns, vector::UnrollVectorOptions()
                    .setNativeShapeFn(nativeShapeFn)
                    .setUnrollTraversalOrderFn(unrollOrder));
}

//===----------------------------------------------------------------------===//
// EliminateBarriersOp
//===----------------------------------------------------------------------===//

void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuEliminateBarriersPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Block and thread mapping utilities.
//===----------------------------------------------------------------------===//

namespace {
/// Local types used for mapping verification.
struct MappingKind {};
struct BlockMappingKind : MappingKind {};
struct ThreadMappingKind : MappingKind {};
} // namespace

static DiagnosedSilenceableFailure
definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
                      Operation *target, const Twine &message) {
  if (transformOp.has_value())
    return transformOp->emitDefiniteFailure() << message;
  return emitDefiniteFailure(target, message);
}

/// Check that the mapping attributes of the scf.forall op are all of the kind
/// expected by `MappingKindType`.
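///
/// For example (illustrative only), a mapping such as
/// [#gpu.block<x>, #gpu.block<y>] is accepted for BlockMappingKind, while
/// mixing kinds (e.g. #gpu.block<x> with #gpu.thread<y>), duplicating mapping
/// ids, or mixing linear and non-linear mapping modes is rejected.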
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  if (!forallOp.getMapping().has_value()) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute");
  }

  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
                                      llvm::IsaPred<GPUBlockMappingAttr>);
  bool hasWarpgroupMapping = llvm::any_of(
      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPUWarpMappingAttr>);
  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
                                       llvm::IsaPred<GPUThreadMappingAttr>);
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }
  if (std::is_same<MappingKindType, BlockMappingKind>::value &&
      !hasBlockMapping) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "scf.forall op requires a mapping attribute of kind 'block'");
  }
  if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
      !hasThreadMapping && !hasWarpMapping && !hasWarpgroupMapping) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute "
                                 "of kind 'thread' or 'warp'");
  }

  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicate attribute, cannot map different loops "
          "to the same mapping id");
    }
    seen.insert(map);
  }

  auto isLinear = [](Attribute a) {
    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
  };
  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix linear and non-linear mapping modes");
  }

  return DiagnosedSilenceableFailure::success();
}

template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  // Check that the types of the mapping attributes match.
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  // Perform other, non-type-related verifications.
  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
                              forallOp.getMapping()->getValue().front())
                              .isLinearMapping();
  // TODO: This would be more natural with support for Optional<EnumParameter>
  // in GPUDeviceMappingAttr.
  int64_t maxNumMappingsSupported =
      useLinearMapping ? (getMaxEnumValForMappingId() -
                          static_cast<uint64_t>(MappingId::DimZ))
                       : 3;
  if (forallOp.getRank() > maxNumMappingsSupported) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
  }
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  return DiagnosedSilenceableFailure::success();
}

/// Struct to return the result of the rewrite of a forall operation.
struct ForallRewriteResult {
  SmallVector<int64_t> mappingSizes;
  SmallVector<Value> mappingIds;
};

/// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR.
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
                            OperationOrBlock *parent, Value replacement,
                            ArrayRef<int64_t> availableMappingSizes) {
  parent->walk([&](OpTy idOp) {
    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
  });
}

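/// Rewrite one normalized, statically sized scf.forall op into uses of the
/// hardware ids produced by `gpuIdBuilder`, predicated by an scf.if when the
/// available mapping sizes exceed the loop trip counts.
///
/// For example (illustrative only), an op such as
///   scf.forall (%i, %j) in (7, 9) { ... }
///       {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
/// has %i replaced by the y id and %j by the x id, guarded by `id < trip
/// count` comparisons when the enclosing block is larger than 7x9.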
static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
    ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
  LDBG("--start rewriteOneForallCommonImpl");

  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
         "requires statically sized, normalized forall op");
  SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
  SetVector<Attribute> forallMappingAttrs;
  forallMappingAttrs.insert(forallOp.getMapping()->getValue().begin(),
                            forallOp.getMapping()->getValue().end());
  auto comparator = [](Attribute a, Attribute b) -> bool {
    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
           cast<DeviceMappingAttrInterface>(b).getMappingId();
  };

  // Step 1.b. In the linear case, compute the max mapping to avoid needlessly
  // mapping all dimensions. In the 3-D mapping case we need to map all
  // dimensions.
  DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>(
      *llvm::max_element(forallMappingAttrs, comparator));
  DeviceMappingAttrInterface maxLinearMapping;
  if (maxMapping.isLinearMapping())
    maxLinearMapping = maxMapping;
  for (auto attr : gpuIdBuilder.mappingAttributes) {
    // If attr overflows, just skip.
    if (maxLinearMapping && comparator(maxLinearMapping, attr))
      continue;
    // Try to insert. If element was already present, just continue.
    if (!forallMappingAttrs.insert(attr))
      continue;
    // Otherwise, we have a new insertion without a size -> use size 1.
    tmpMappingSizes.push_back(1);
  }
  LLVM_DEBUG(
      llvm::interleaveComma(
          tmpMappingSizes,
          DBGS() << "----tmpMappingSizes extracted from scf.forall op: ");
      llvm::dbgs() << "\n");

  // Step 2. Sort the values by the corresponding DeviceMappingAttrInterface.
  SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
      forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 forallMappingAttrs, DBGS() << "----forallMappingAttrs: ");
             llvm::dbgs() << "\n");

  // Step 3. Generate the mappingIdOps using the provided generator.
  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forallOp);
  SmallVector<int64_t> originalBasis(availableMappingSizes);
  bool originalBasisWasProvided = !originalBasis.empty();
  if (!originalBasisWasProvided) {
    originalBasis = forallMappingSizes;
    while (originalBasis.size() < 3)
      originalBasis.push_back(1);
  }

  IdBuilderResult builderResult =
      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);

  // Step 4. Map the induction variables to the mappingIdOps; this may involve
  // a permutation.
  SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
  IRMapping bvm;
  for (auto [iv, dim] : llvm::zip_equal(
           forallOp.getInductionVars(),
           forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
    Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
    bvm.map(iv, peIdOp);
  }

  // Step 5. If the originalBasis is already known, create conditionals to
  // predicate the region. Otherwise, the current forall determines the
  // originalBasis and no predication occurs.
  Value predicate;
  if (originalBasisWasProvided) {
    SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes;
    SmallVector<int64_t> availableMappingSizes =
        builderResult.availableMappingSizes;
    SmallVector<Value> activeIdOps = builderResult.activeIdOps;
    // clang-format off
    LLVM_DEBUG(
        llvm::interleaveComma(
          activeMappingSizes, DBGS() << "----activeMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(
          availableMappingSizes, DBGS() << "----availableMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(activeIdOps, DBGS() << "----activeIdOps: ");
        llvm::dbgs() << "\n");
    // clang-format on
    for (auto [activeId, activeMappingSize, availableMappingSize] :
         llvm::zip_equal(activeIdOps, activeMappingSizes,
                         availableMappingSizes)) {
      if (activeMappingSize > availableMappingSize) {
        return definiteFailureHelper(
            transformOp, forallOp,
            "Trying to map to fewer GPU threads than loop iterations but "
            "overprovisioning is not yet supported. "
            "Try additional tiling before mapping or map to more threads.");
      }
      if (activeMappingSize == availableMappingSize)
        continue;
      Value idx =
          rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
      Value tmpPredicate = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, activeId, idx);
      LDBG("----predicate: " << tmpPredicate);
      predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                             tmpPredicate)
                            : tmpPredicate;
    }
  }

  // Step 6. Move the body of forallOp.
  // Erase the terminator first; it will not be used.
  rewriter.eraseOp(forallOp.getTerminator());
  Block *targetBlock;
  Block::iterator insertionPoint;
  if (predicate) {
    // Step 6.a. If predicated, move at the beginning.
    auto ifOp = rewriter.create<scf::IfOp>(loc, predicate,
                                           /*withElseRegion=*/false);
    targetBlock = ifOp.thenBlock();
    insertionPoint = ifOp.thenBlock()->begin();
  } else {
    // Step 6.b. Otherwise, move inline just at the rewriter insertion
    // point.
    targetBlock = forallOp->getBlock();
    insertionPoint = rewriter.getInsertionPoint();
  }
  Block &sourceBlock = forallOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // Step 7. RAUW indices.
  for (Value loopIndex : forallOp.getInductionVars()) {
    Value threadIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
  }

  // Step 8. Erase old op.
  rewriter.eraseOp(forallOp);

  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----result forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 mappingIdOps, DBGS() << "----result mappingIdOps: ");
             llvm::dbgs() << "\n");

  result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MapForallToBlocks
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
    RewriterBase &rewriter, TransformOpInterface transformOp,
    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
    const GpuIdBuilder &gpuIdBuilder) {
  LDBG("Start mapForallToBlocksImpl");

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<BlockMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  Block *parentBlock = forallOp->getBlock();
  Value zero;
  {
    // Create an early zero index value for replacements and immediately reset
    // the insertion point.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(parentBlock);
    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  }

  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp,
      /*availableMappingSizes=*/gridDims, rewriteResult, gpuIdBuilder);

  // Return if anything goes wrong; use silenceable failure as a match
  // failure.
  if (!diag.succeeded())
    return diag;

  // If gridDims was not provided already, set it from the return.
  if (gridDims.empty()) {
    gridDims = rewriteResult.mappingSizes;
    while (gridDims.size() < 3)
      gridDims.push_back(1);
  }
  assert(gridDims.size() == 3 && "Need 3-D gridDims");

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
                                          rewriteResult.mappingSizes);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
mlir::transform::gpu::findTopLevelForallOp(Operation *target,
                                           scf::ForallOp &topLevelForallOp,
                                           TransformOpInterface transformOp) {
  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
    if (forallOp->getParentOfType<scf::ForallOp>())
      return WalkResult::advance();
    if (topLevelForallOp)
      // TODO: Handle multiple forall if they are independent.
      return WalkResult::interrupt();
    topLevelForallOp = forallOp;
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted() || !topLevelForallOp)
    return transformOp.emitSilenceableError()
           << "could not find a unique topLevel scf.forall";
  return DiagnosedSilenceableFailure::success();
}

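// MapForallToBlocks is driven from transform IR: it either targets an existing
// gpu.launch or creates one when the `generate_gpu_launch` attribute is set,
// then maps the unique top-level scf.forall of the target to block ids using
// the (optional, size-3) `grid_dims`.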
DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag = mlir::transform::gpu::findTopLevelForallOp(
      target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  assert(topLevelForallOp && "expect an scf.forall");

  SmallVector<int64_t> gridDims{getGridDims()};
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure("transform requires size-3 mapping");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(topLevelForallOp);

  // Generate the gpu.launch here and move the forall inside.
  if (getGenerateGpuLaunch()) {
    DiagnosedSilenceableFailure diag =
        createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded())
      return diag;

    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForallOp = rewriter.clone(*topLevelForallOp);
    rewriter.eraseOp(topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(newForallOp);
  }

  // The BlockIdBuilder adapts to whatever is thrown at it.
  bool useLinearMapping = false;
  if (topLevelForallOp.getMapping()) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(
        topLevelForallOp.getMapping()->getValue().front());
    useLinearMapping = mappingAttr.isLinearMapping();
  }
  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);

  diag = mlir::transform::gpu::mapForallToBlocksImpl(
      rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Set the GPU launch configuration for the grid dims late; this is
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        cast<TransformOpInterface>(getOperation()), gridDims[0],
                        gridDims[1], gridDims[2]);

  results.push_back(gpuLaunch);
  return diag;
}

LogicalResult transform::MapForallToBlocks::verify() {
  if (!getGridDims().empty() && getGridDims().size() != 3) {
    return emitOpError() << "transform requires empty or size-3 grid_dims";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MapNestedForallToThreads
//===----------------------------------------------------------------------===//

static DiagnosedSilenceableFailure checkMappingSpec(
    std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
    ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
    int factor, bool useLinearMapping = false) {
  if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
            std::to_string(factor));
    return diag;
  }
  if (computeProduct(numParallelIterations) * factor >
      computeProduct(blockOrGridSizes)) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("the number of required parallel resources (blocks or "
              "threads) ") +
            std::to_string(computeProduct(numParallelIterations) * factor) +
            std::string(" overflows the number of available resources ") +
            std::to_string(computeProduct(blockOrGridSizes)));
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}

static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                   scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                   int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
  auto mappingAttr = cast<DeviceMappingAttrInterface>(
      forallOp.getMapping()->getValue().front());
  bool useLinearMapping = mappingAttr.isLinearMapping();

  // Sanity checks that may result in runtime verification errors.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
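  // The mapping granularity determines how many hardware threads each mapped
  // id consumes: kNumWarpsPerGroup * warpSize threads per warpgroup, warpSize
  // threads per warp, and a single thread per thread id.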
  int64_t factor = 1;
  if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
  } else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
    factor = warpSize;
  }
  DiagnosedSilenceableFailure diag =
      checkMappingSpec(transformOp, forallOp, numParallelIterations.value(),
                       blockSizes, factor, useLinearMapping);
  if (!diag.succeeded())
    return diag;

  // Start mapping.
  MLIRContext *ctx = forallOp.getContext();
  gpuIdBuilder =
      TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
          .Case([&](GPUWarpgroupMappingAttr) {
            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUWarpMappingAttr) {
            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUThreadMappingAttr) {
            return GpuThreadIdBuilder(ctx, useLinearMapping);
          })
          .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
            llvm_unreachable("unknown mapping attribute");
          });
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
    bool syncAfterDistribute) {

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  GpuIdBuilder gpuIdBuilder;
  {
    // Try to construct the id builder; if it fails, return.
    DiagnosedSilenceableFailure diag = getThreadIdBuilder(
        transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  // Insert after to allow for syncthreads after `forall` is erased.
  rewriter.setInsertionPointAfter(forallOp);
  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder);
  if (!diag.succeeded())
    return diag;
  // Add a syncthreads if needed. TODO: warpsync
  if (syncAfterDistribute)
    rewriter.create<BarrierOp>(loc);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
    bool syncAfterDistribute) {
  LDBG("Start mapNestedForallToThreadsImpl");
  if (blockDims.size() != 3) {
    return definiteFailureHelper(transformOp, target,
                                 "requires size-3 thread mapping");
  }

  // Create an early zero index value for replacements.
  Location loc = target->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
        rewriter, transformOp, forallOp, blockDims, warpSize,
        syncAfterDistribute);
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();
    if (diag.succeeded())
      return WalkResult::skip();
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return diag;

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
                                          blockDims);

  return DiagnosedSilenceableFailure::success();
}

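// MapNestedForallToThreads is driven from transform IR: it targets a
// gpu.launch payload op, fixes its block sizes to the size-3 `block_dims`,
// maps every nested scf.forall with a thread/warp/warpgroup mapping to the
// corresponding ids, and optionally inserts a barrier after each mapped loop
// (`sync_after_distribute`).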
DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  // Basic high-level verifications.
  if (!gpuLaunch)
    return emitSilenceableError() << "Given target is not a gpu.launch";

  // Mapping to block ids.
  SmallVector<int64_t> blockDims{getBlockDims()};
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                     blockDims[0], blockDims[1], blockDims[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
    return diag;
  }

  // Set the GPU launch configuration for the block dims early; this is not
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
                        std::nullopt, std::nullopt, blockDims[0], blockDims[1],
                        blockDims[2]);

  rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
  diag =
      mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
                                   getWarpSize(), getSyncAfterDistribute());

  results.push_back(gpuLaunch.getOperation());
  return diag;
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
/// Registers the GPU transform ops and declares the dialects they may
/// generate (SCF, arith, GPU) so they get loaded when the extension is
/// applied.
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  GPUTransformDialectExtension() {
    declareGeneratedDialect<scf::SCFDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<GPUDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"

void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}

