| 1 | //===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h" |
| 10 | |
| 11 | #include "mlir/Analysis/SliceAnalysis.h" |
| 12 | #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" |
| 13 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
| 14 | #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" |
| 15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 17 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 18 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 19 | #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
| 20 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 21 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 22 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
| 23 | #include "mlir/Dialect/NVGPU/Transforms/Transforms.h" |
| 24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 25 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
| 26 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 27 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 28 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 29 | #include "mlir/IR/AffineExpr.h" |
| 30 | #include "mlir/IR/BuiltinTypes.h" |
| 31 | #include "mlir/IR/Value.h" |
| 32 | #include "llvm/ADT/ArrayRef.h" |
| 33 | |
| 34 | using namespace mlir; |
| 35 | using namespace mlir::linalg; |
| 36 | using namespace mlir::nvgpu; |
| 37 | using namespace mlir::NVVM; |
| 38 | using namespace mlir::transform; |
| 39 | |
| 40 | #define DEBUG_TYPE "nvgpu-transforms" |
| 41 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
| 42 | #define DBGSNL() (llvm::dbgs() << "\n") |
| 43 | #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n") |
| 44 | |
| 45 | //===----------------------------------------------------------------------===// |
| 46 | // Apply...ConversionPatternsOp |
| 47 | //===----------------------------------------------------------------------===// |
| 48 | |
| 49 | void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns( |
| 50 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
| 51 | auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); |
| 52 | /// device-side async tokens cannot be materialized in nvvm. We just |
| 53 | /// convert them to a dummy i32 type in order to easily drop them during |
| 54 | /// conversion. |
| 55 | populateGpuMemorySpaceAttributeConversions( |
| 56 | llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned { |
| 57 | switch (space) { |
| 58 | case gpu::AddressSpace::Global: |
| 59 | return static_cast<unsigned>( |
| 60 | NVVM::NVVMMemorySpace::kGlobalMemorySpace); |
| 61 | case gpu::AddressSpace::Workgroup: |
| 62 | return static_cast<unsigned>( |
| 63 | NVVM::NVVMMemorySpace::kSharedMemorySpace); |
| 64 | case gpu::AddressSpace::Private: |
| 65 | return 0; |
| 66 | } |
        llvm_unreachable("unknown address space enum value");
| 68 | return 0; |
| 69 | }); |
| 70 | llvmTypeConverter.addConversion( |
| 71 | [&](nvgpu::DeviceAsyncTokenType type) -> Type { |
| 72 | return llvmTypeConverter.convertType( |
| 73 | IntegerType::get(type.getContext(), 32)); |
| 74 | }); |
| 75 | llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type { |
| 76 | return llvmTypeConverter.convertType( |
| 77 | IntegerType::get(type.getContext(), 64)); |
| 78 | }); |
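  // Illustrative sketch of the layout produced below (assuming kWgmmaSizeM is
  // the wgmma M-tile size, 64): an accumulator fragmented as
  // vector<128x128xf32> becomes a struct of 128/64 = 2 inner structs, each
  // holding 128/2 = 64 f32 members.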
| 79 | llvmTypeConverter.addConversion( |
| 80 | [&](nvgpu::WarpgroupAccumulatorType type) -> Type { |
| 81 | Type elemType = type.getFragmented().getElementType(); |
| 82 | int64_t sizeM = type.getFragmented().getDimSize(0); |
| 83 | int64_t sizeN = type.getFragmented().getDimSize(1); |
| 84 | |
| 85 | unsigned numMembers; |
| 86 | if (elemType.isF32() || elemType.isInteger(32)) |
| 87 | numMembers = sizeN / 2; |
| 88 | else if (elemType.isF16()) |
| 89 | numMembers = sizeN / 4; |
| 90 | else |
          llvm_unreachable("unsupported type for warpgroup accumulator");
| 92 | |
| 93 | SmallVector<Type> innerStructBody; |
| 94 | for (unsigned i = 0; i < numMembers; i++) |
| 95 | innerStructBody.push_back(elemType); |
| 96 | auto innerStructType = LLVM::LLVMStructType::getLiteral( |
| 97 | type.getContext(), innerStructBody); |
| 98 | |
| 99 | SmallVector<Type> structBody; |
| 100 | for (int i = 0; i < sizeM; i += kWgmmaSizeM) |
| 101 | structBody.push_back(innerStructType); |
| 102 | |
| 103 | auto convertedType = |
| 104 | LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); |
| 105 | return llvmTypeConverter.convertType(convertedType); |
| 106 | }); |
| 107 | llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type { |
| 108 | return llvmTypeConverter.convertType( |
| 109 | getMBarrierMemrefType(type.getContext(), type)); |
| 110 | }); |
| 111 | llvmTypeConverter.addConversion( |
| 112 | [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type { |
| 113 | return llvmTypeConverter.convertType( |
| 114 | IntegerType::get(type.getContext(), 64)); |
| 115 | }); |
| 116 | llvmTypeConverter.addConversion( |
| 117 | [&](nvgpu::TensorMapDescriptorType type) -> Type { |
| 118 | return LLVM::LLVMPointerType::get(type.getContext()); |
| 119 | }); |
| 120 | populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns); |
| 121 | } |
| 122 | |
| 123 | LogicalResult |
| 124 | transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter( |
| 125 | transform::TypeConverterBuilderOpInterface builder) { |
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
| 128 | return success(); |
| 129 | } |
| 130 | |
| 131 | //===---------------------------------------------------------------------===// |
| 132 | // CreateAsyncGroupsOp |
| 133 | //===---------------------------------------------------------------------===// |
| 134 | |
| 135 | void transform::CreateAsyncGroupsOp::getEffects( |
| 136 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 137 | transform::consumesHandle(getTargetMutable(), effects); |
| 138 | transform::producesHandle(getOperation()->getOpResults(), effects); |
| 139 | transform::modifiesPayload(effects); |
| 140 | } |
| 141 | |
| 142 | DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne( |
| 143 | TransformRewriter &rewriter, Operation *target, |
| 144 | ApplyToEachResultList &results, TransformState &state) { |
| 145 | nvgpu::createAsyncGroups(rewriter, target, getBypassL1()); |
| 146 | results.push_back(target); |
| 147 | return DiagnosedSilenceableFailure::success(); |
| 148 | } |
| 149 | |
| 150 | //===----------------------------------------------------------------------===// |
| 151 | // PipelineSharedMemoryCopiesOp |
| 152 | //===----------------------------------------------------------------------===// |
| 153 | |
| 154 | /// Returns true if the given type has the default memory space. |
| 155 | static bool hasDefaultMemorySpace(BaseMemRefType type) { |
| 156 | return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0; |
| 157 | } |
| 158 | |
| 159 | /// Returns true if the given type has the shared (workgroup) memory space. |
| 160 | static bool hasSharedMemorySpace(BaseMemRefType type) { |
| 161 | auto space = |
| 162 | dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace()); |
| 163 | return space && |
| 164 | space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace(); |
| 165 | } |
| 166 | |
| 167 | /// Returns the value produced by a load from the default memory space. Returns |
| 168 | /// null if the operation is not such a load. |
| 169 | static Value getValueLoadedFromGlobal(Operation *op) { |
| 170 | // TODO: consider an interface or leveraging the memory effects interface. |
| 171 | auto load = dyn_cast<vector::TransferReadOp>(op); |
| 172 | if (!load) |
| 173 | return nullptr; |
| 174 | |
| 175 | auto loadType = dyn_cast<MemRefType>(load.getBase().getType()); |
| 176 | if (!loadType || !hasDefaultMemorySpace(loadType)) |
| 177 | return nullptr; |
| 178 | return load; |
| 179 | } |
| 180 | |
| 181 | /// Returns true if the operation is storing the given value into shared memory. |
| 182 | static bool isStoreToShared(Operation *op, Value v) { |
  // TODO: consider an interface or leveraging the memory effects interface.
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;

  auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
  return storeType && hasSharedMemorySpace(storeType);
| 190 | } |
| 191 | |
| 192 | /// Returns true if the operation is a load from the default memory space the |
| 193 | /// result of which is only stored into the shared memory space. |
| 194 | static bool isLoadFromGlobalStoredToShared(Operation *op) { |
| 195 | Value loaded = getValueLoadedFromGlobal(op); |
| 196 | if (!loaded || !loaded.hasOneUse()) |
| 197 | return false; |
| 198 | |
  return isStoreToShared(*loaded.getUsers().begin(), loaded);
| 200 | } |
| 201 | |
| 202 | /// Populate `ops` with the set of operations that belong to the stage 0 of the |
| 203 | /// pipelined version of the given loop when pipelining copies to shared memory. |
| 204 | /// Specifically, this collects: |
| 205 | /// |
| 206 | /// 1. all loads from global memory, both sync and async; |
| 207 | /// 2. the barriers for async loads. |
| 208 | /// |
| 209 | /// In particular, barriers are omitted if they do not dominate at least one |
| 210 | /// async load for which there is not yet a barrier. |
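///
/// Illustrative sketch of a loop body (hypothetical, not verbatim IR):
///
///   scf.for %i = ... {
///     gpu.barrier                              // collected (precedes a copy)
///     nvgpu.device_async_copy ...              // collected
///     nvgpu.device_async_create_group ...      // collected
///     %v = vector.transfer_read %global[...]   // collected when its only use
///     vector.transfer_write %v, %shared[...]   //   is a store to shared
///     ... compute on %shared ...               // not collected
///   }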
| 211 | static LogicalResult |
| 212 | collectStage0PipeliningOps(scf::ForOp forOp, |
| 213 | llvm::SmallPtrSet<Operation *, 16> &ops) { |
| 214 | |
| 215 | llvm::SmallPtrSet<Operation *, 4> barriers; |
| 216 | for (Operation &op : *forOp.getBody()) { |
| 217 | // Bail on nested ops for now. |
| 218 | if (op.getNumRegions() > 0) |
| 219 | return failure(); |
| 220 | |
| 221 | if (isa<gpu::BarrierOp>(op)) { |
| 222 | barriers.insert(&op); |
| 223 | continue; |
| 224 | } |
| 225 | |
| 226 | if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) { |
| 227 | ops.insert(&op); |
| 228 | ops.insert(std::make_move_iterator(barriers.begin()), |
| 229 | std::make_move_iterator(barriers.end())); |
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
| 232 | continue; |
| 233 | } |
| 234 | |
| 235 | if (isLoadFromGlobalStoredToShared(&op)) { |
| 236 | ops.insert(&op); |
| 237 | continue; |
| 238 | } |
| 239 | } |
| 240 | |
| 241 | return success(); |
| 242 | } |
| 243 | |
| 244 | /// Hook for the loop pipeliner that sets the "num groups in flight" attribute |
| 245 | /// of async wait operations corresponding to pipelined shared memory copies. |
| 246 | // TODO: this currently assumes that there are no groups that could be in flight |
| 247 | // in the existing code. |
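// For example (sketch), with depth == 3 the prologue and kernel waits are set
// to 2 groups in flight, while epilogue iterations 0, 1 and 2 get 2, 1 and 0
// respectively, draining one group per iteration.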
| 248 | static void |
| 249 | setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op, |
| 250 | scf::PipeliningOption::PipelinerPart part, |
| 251 | unsigned iteration, unsigned depth) { |
| 252 | // Based on the order of copies within the loop we need to set the number |
| 253 | // of copies in flight, unless it is already set. |
| 254 | auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op); |
| 255 | if (!waitOp || waitOp.getNumGroups()) |
| 256 | return; |
| 257 | |
| 258 | int numGroupInFlight = 0; |
| 259 | if (part == scf::PipeliningOption::PipelinerPart::Kernel || |
| 260 | part == scf::PipeliningOption::PipelinerPart::Prologue) { |
| 261 | numGroupInFlight = depth - 1; |
| 262 | } else { |
| 263 | // By construction there should be no wait op in the prologue as all the |
| 264 | // wait should be in the last stage. |
| 265 | assert(part == scf::PipeliningOption::PipelinerPart::Epilogue); |
| 266 | // Based on the schedule we pick we know how many groups are in flight for |
| 267 | // each iteration of the epilogue. |
| 268 | numGroupInFlight = depth - 1 - iteration; |
| 269 | } |
| 270 | waitOp.setNumGroups(numGroupInFlight); |
| 271 | } |
| 272 | |
| 273 | /// Hook for the loop pipeliner that populates `ops` with the stage information |
| 274 | /// as follows: |
| 275 | /// |
| 276 | /// - operations in `stage0Ops` (typically loads from global memory and |
| 277 | /// related barriers) are at stage 0; |
| 278 | /// - operations in the backward slice of any stage0Ops are all at stage 0; |
| 279 | /// - other operations are at stage `depth`; |
| 280 | /// - the internal order of the pipelined loop has ops at stage `depth` first, |
| 281 | /// then those at stage 0, with relative order within each group preserved. |
| 282 | /// |
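/// For instance (sketch), with `depth` == 2 and a loop body
/// [device_async_copy, device_async_create_group, compute..., yield] where the
/// first two ops are stage-0 ops, the resulting schedule is
/// [(compute..., 2), (device_async_copy, 0), (device_async_create_group, 0)];
/// the yield is never assigned a stage.
///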
| 283 | static void getPipelineStages( |
| 284 | scf::ForOp forOp, |
| 285 | std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages, |
| 286 | unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) { |
| 287 | SetVector<Operation *> dependencies; |
| 288 | BackwardSliceOptions options([&](Operation *visited) { |
| 289 | return visited->getBlock() == forOp.getBody(); |
| 290 | }); |
| 291 | options.inclusive = true; |
| 292 | for (Operation &op : forOp.getBody()->getOperations()) { |
| 293 | if (stage0Ops.contains(&op)) { |
| 294 | LogicalResult result = getBackwardSlice(&op, &dependencies, options); |
      assert(result.succeeded() && "expected a backward slice");
| 296 | (void)result; |
| 297 | } |
| 298 | } |
| 299 | |
| 300 | for (Operation &op : forOp.getBody()->getOperations()) { |
| 301 | if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op)) |
| 302 | opsWithPipelineStages.emplace_back(&op, depth); |
| 303 | } |
| 304 | for (Operation &op : forOp.getBody()->getOperations()) { |
| 305 | if (dependencies.contains(&op)) |
| 306 | opsWithPipelineStages.emplace_back(&op, 0); |
| 307 | } |
| 308 | } |
| 309 | |
| 310 | /// Hook for the loop pipeliner. Replaces op with a predicated version and |
| 311 | /// returns the resulting operation. Returns the original op if the predication |
| 312 | /// isn't necessary for the given op. Returns null if predication is needed but |
| 313 | /// not supported. |
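/// Sketch of the rewrite applied to an async copy (pseudo-IR, names
/// illustrative):
///   %c0    = arith.constant 0 : index
///   %elems = arith.select %predicate, %originalSrcElements, %c0
///   nvgpu.device_async_copy ..., %elems  // reads zero elements (zero-fill)
///                                        // when the predicate is false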
| 314 | static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter, |
| 315 | Operation *op, Value predicate) { |
| 316 | // Some operations may be fine to execute "speculatively" more times than the |
| 317 | // original number of iterations, in particular side-effect free operations |
| 318 | // and barriers, even if they cannot be predicated. |
| 319 | if (isMemoryEffectFree(op) || |
| 320 | isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp, |
| 321 | nvgpu::DeviceAsyncWaitOp>(op)) { |
| 322 | return op; |
| 323 | } |
| 324 | |
| 325 | // Otherwise, only async copies can currently be predicated. |
| 326 | auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op); |
| 327 | if (!asyncCopyOp) |
| 328 | return nullptr; |
| 329 | |
| 330 | // Create srcElement Value based on `predicate`. The next lines generate |
| 331 | // the following code: |
| 332 | // |
| 333 | // srcElement = (pred) ? prevSrcElements : 0; |
| 334 | // |
| 335 | Location loc = asyncCopyOp->getLoc(); |
| 336 | Value dstElements = |
| 337 | rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr()); |
| 338 | Value originalSrcElement = |
| 339 | asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements; |
  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
| 341 | auto srcElements = rewriter.create<arith::SelectOp>( |
| 342 | loc, predicate, originalSrcElement, c0Index); |
| 343 | auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>( |
| 344 | loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()), |
| 345 | asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(), |
| 346 | asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements, |
| 347 | UnitAttr()); |
| 348 | rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp); |
| 349 | return asyncCopyZeroFillOp; |
| 350 | } |
| 351 | |
| 352 | /// Applies loop pipelining with the given depth to the given loop so that |
| 353 | /// copies into the shared memory are pipelined. Doesn't affect other loops. |
| 354 | /// Returns a pair containing the error state and the pipelined op, the latter |
| 355 | /// being null in case of any failure. The error state contains a definite error |
| 356 | /// if the IR has been modified and a silenceable error otherwise. |
| 357 | static std::tuple<DiagnosedSilenceableFailure, scf::ForOp> |
| 358 | pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth, |
| 359 | bool epiloguePeeling) { |
| 360 | llvm::SmallPtrSet<Operation *, 16> stage0Ops; |
| 361 | if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) { |
| 362 | return std::make_tuple( |
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
| 364 | scf::ForOp()); |
| 365 | } |
| 366 | if (stage0Ops.empty()) { |
| 367 | return std::make_tuple( |
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
| 369 | } |
| 370 | |
| 371 | scf::PipeliningOption options; |
| 372 | unsigned maxDepth = depth; |
| 373 | auto setAnnotation = [&](Operation *op, |
| 374 | scf::PipeliningOption::PipelinerPart part, |
| 375 | unsigned iteration) { |
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
| 377 | }; |
| 378 | options.getScheduleFn = |
| 379 | [&](scf::ForOp schedulingFor, |
| 380 | std::vector<std::pair<Operation *, unsigned>> &ops) { |
| 381 | if (schedulingFor != forOp) |
| 382 | return; |
| 383 | return getPipelineStages(forOp, ops, maxDepth, stage0Ops); |
| 384 | }; |
| 385 | options.annotateFn = setAnnotation; |
| 386 | if (!epiloguePeeling) { |
| 387 | options.peelEpilogue = false; |
| 388 | options.predicateFn = replaceOpWithPredicatedOp; |
| 389 | } |
| 390 | |
| 391 | OpBuilder::InsertionGuard guard(rewriter); |
| 392 | rewriter.setInsertionPoint(forOp); |
| 393 | bool modifiedIR; |
| 394 | FailureOr<scf::ForOp> maybePipelined = |
| 395 | pipelineForLoop(rewriter, forOp, options, &modifiedIR); |
| 396 | if (succeeded(maybePipelined)) { |
| 397 | return std::make_tuple(DiagnosedSilenceableFailure::success(), |
| 398 | *maybePipelined); |
| 399 | } |
| 400 | return std::make_tuple( |
| 401 | modifiedIR |
| 402 | ? DiagnosedSilenceableFailure::definiteFailure() |
          : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
| 404 | scf::ForOp()); |
| 405 | } |
| 406 | |
| 407 | DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne( |
| 408 | TransformRewriter &rewriter, scf::ForOp forOp, |
| 409 | ApplyToEachResultList &results, TransformState &state) { |
| 410 | auto [diag, pipelined] = pipelineForSharedCopies( |
| 411 | rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue()); |
| 412 | if (diag.succeeded()) { |
| 413 | results.push_back(pipelined); |
| 414 | return DiagnosedSilenceableFailure::success(); |
| 415 | } |
| 416 | if (diag.isDefiniteFailure()) { |
    auto diag = emitDefiniteFailure("irreversible pipelining failure");
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
| 420 | diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName(); |
| 421 | } |
| 422 | return diag; |
| 423 | } |
| 424 | |
| 425 | return std::move(diag); |
| 426 | } |
| 427 | |
| 428 | //===----------------------------------------------------------------------===// |
| 429 | // RewriteMatmulAsMmaSyncOp |
| 430 | //===----------------------------------------------------------------------===// |
| 431 | |
| 432 | /// Helper struct to encode a pair of row/column indexings in the form of |
| 433 | /// affine expressions. |
| 434 | struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> { |
| 435 | RowColIndexing(AffineExpr row, AffineExpr col) |
| 436 | : std::pair<AffineExpr, AffineExpr>(row, col) {} |
| 437 | |
| 438 | AffineExpr row() const { return first; }; |
| 439 | AffineExpr col() const { return second; }; |
| 440 | |
| 441 | void print(llvm::raw_ostream &os) const { |
| 442 | os << "- indexing: " << first << ", " << second; |
| 443 | } |
| 444 | }; |
| 445 | |
| 446 | /// Helper struct to provide a simple mapping from matmul operations to the |
| 447 | /// corresponding mma.sync operation. This is constrained to the case where the |
| 448 | /// matmul matches the mma.sync operation 1-1. |
| 449 | struct MmaSyncBuilder { |
| 450 | MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId) |
| 451 | : b(b), loc(loc), laneId(laneId) {} |
| 452 | |
| 453 | using IndexCalculator = |
| 454 | std::function<SmallVector<RowColIndexing>(MLIRContext *)>; |
| 455 | |
| 456 | /// Create the mma.sync operation corresponding to `linalgOp` along with all |
| 457 | /// the supporting load/store and vector operations. |
| 458 | FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp); |
| 459 | |
| 460 | private: |
| 461 | struct MmaSyncInfo { |
| 462 | std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns; |
| 463 | std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>> |
| 464 | vectorShapes; |
| 465 | SmallVector<int64_t> mmaShape; |
| 466 | bool tf32Enabled; |
| 467 | }; |
| 468 | |
| 469 | /// Return the specific index calculator for the given `linalgOp` or failure |
| 470 | /// if the op is not supported. This is the toplevel switch that should just |
| 471 | /// be Tablegen'd in the future. |
| 472 | FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape, |
| 473 | TypeRange elementalTypes); |
| 474 | |
| 475 | //===--------------------------------------------------------------------===// |
| 476 | // Instruction-specific row, column indexing expression builders. |
| 477 | // These should all be declaratively specified via Tablegen in the future. |
| 478 | // The Tablegen specification should be as straightforward as possible to |
| 479 | // only model the existing size and type combinations. |
| 480 | //===--------------------------------------------------------------------===// |
| 481 | // |
| 482 | // TODO: Tablegen all this. |
| 483 | //===--------------------------------------------------------------------===// |
| 484 | // m16n8k4 tf32 case. |
| 485 | //===--------------------------------------------------------------------===// |
| 486 | /// From the NVIDIA doc: |
| 487 | /// groupID = %laneid >> 2 |
| 488 | /// threadIDInGroup = %laneid % 4 |
| 489 | /// row = groupID for a0 |
| 490 | /// groupID + 8 for a1 |
| 491 | /// col = threadIDInGroup |
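  ///
  /// For example, following the formulas above, %laneid == 5 gives
  /// groupID = 1 and threadIDInGroup = 1, so a0 is read at (row, col) = (1, 1)
  /// and a1 at (9, 1).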
| 492 | static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) { |
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
| 496 | return {RowColIndexing{groupID, threadIDInGroup}, |
| 497 | RowColIndexing{groupID + 8, threadIDInGroup}}; |
| 498 | } |
| 499 | |
| 500 | /// From the NVIDIA doc: |
| 501 | /// groupID = %laneid >> 2 |
| 502 | /// threadIDInGroup = %laneid % 4 |
| 503 | /// row = threadIDInGroup |
| 504 | /// col = groupID |
| 505 | static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) { |
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
| 509 | return {RowColIndexing{threadIDInGroup, groupID}}; |
| 510 | } |
| 511 | |
| 512 | /// From the NVIDIA doc: |
| 513 | /// groupID = %laneid >> 2 |
| 514 | /// threadIDInGroup = %laneid % 4 |
| 515 | /// row = groupID for c0 and c1 |
| 516 | /// groupID + 8 for c2 and c3 |
| 517 | /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3} |
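  ///
  /// For example, following the formulas above, %laneid == 5 gives
  /// groupID = 1 and threadIDInGroup = 1, so c0..c3 live at (1, 2), (1, 3),
  /// (9, 2) and (9, 3) respectively.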
| 518 | static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) { |
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
| 522 | return {RowColIndexing{groupID, threadIDInGroup * 2 + 0}, |
| 523 | RowColIndexing{groupID, threadIDInGroup * 2 + 1}, |
| 524 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, |
| 525 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}}; |
| 526 | } |
| 527 | |
| 528 | //===--------------------------------------------------------------------===// |
| 529 | // m16n8k16 f16 case. |
| 530 | //===--------------------------------------------------------------------===// |
| 531 | /// From the NVIDIA doc: |
| 532 | /// groupID = %laneid >> 2 |
| 533 | /// threadIDInGroup = %laneid % 4 |
| 534 | /// |
| 535 | /// row = groupID for ai where 0 <= i < 2 || 4 <= i < 6 |
| 536 | /// groupID + 8 Otherwise |
| 537 | /// |
| 538 | /// col = (threadIDInGroup * 2) + (i & 0x1) for ai where i < 4 |
| 539 | /// (threadIDInGroup * 2) + (i & 0x1) + 8 for ai where i >= 4 |
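  ///
  /// For example, following the formulas above, %laneid == 5 gives
  /// groupID = 1 and threadIDInGroup = 1, so a0..a7 live at (1, 2), (1, 3),
  /// (9, 2), (9, 3), (1, 10), (1, 11), (9, 10) and (9, 11) respectively.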
| 540 | static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) { |
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
| 544 | // clang-format off |
| 545 | return { |
| 546 | RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0 |
| 547 | RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1 |
| 548 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2 |
| 549 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}, // i == 3 |
| 550 | RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8}, // i == 4 |
| 551 | RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8}, // i == 5 |
| 552 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6 |
| 553 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8} // i == 7 |
| 554 | }; |
| 555 | // clang-format on |
| 556 | } |
| 557 | |
| 558 | /// From the NVIDIA doc: |
| 559 | /// groupID = %laneid >> 2 |
| 560 | /// threadIDInGroup = %laneid % 4 |
| 561 | /// |
| 562 | /// row = (threadIDInGroup * 2) + (i & 0x1) for bi where i < 2 |
| 563 | /// (threadIDInGroup * 2) + (i & 0x1) + 8 for bi where i >= 2 |
| 564 | /// |
| 565 | /// col = groupID |
| 566 | static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) { |
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
| 570 | // clang-format off |
| 571 | return { |
| 572 | RowColIndexing{threadIDInGroup * 2 + 0, groupID}, // i == 0 |
| 573 | RowColIndexing{threadIDInGroup * 2 + 1, groupID}, // i == 1 |
| 574 | RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2 |
| 575 | RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID} // i == 3 |
| 576 | }; |
| 577 | // clang-format on |
| 578 | } |
| 579 | |
| 580 | /// From the NVIDIA doc: |
| 581 | /// groupID = %laneid >> 2 |
| 582 | /// threadIDInGroup = %laneid % 4 |
| 583 | /// |
| 584 | /// row = groupID for ci where i < 2 |
| 585 | /// groupID + 8 for ci where i >= 2 |
| 586 | /// |
| 587 | /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3} |
| 588 | static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) { |
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
| 592 | // clang-format off |
| 593 | return { |
| 594 | RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0 |
| 595 | RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1 |
| 596 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2 |
| 597 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1} // i == 3 |
| 598 | }; |
| 599 | // clang-format on |
| 600 | } |
| 601 | |
| 602 | //===--------------------------------------------------------------------===// |
  /// Helper functions to create customizable load and store operations. The
| 604 | /// specific shapes of each MMA instruction are passed via the |
| 605 | /// IndexCalculator callback. |
| 606 | //===--------------------------------------------------------------------===// |
| 607 | /// Build a list of memref.load operations indexed at `(row, col)` indices |
| 608 | /// that make sense for a particular MMA instruction and specified via the |
| 609 | /// IndexCalculator callback. |
| 610 | SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc, |
| 611 | OpFoldResult laneId, Value memref, |
| 612 | const IndexCalculator &indexFn); |
| 613 | |
| 614 | /// Perform a distributed load of a vector operand of `vectorShape` for a |
| 615 | /// particular MMA instruction whose `(row, col)` indices are specified via |
| 616 | /// the IndexCalculator callback. Each `laneId` loads the subportion of the |
| 617 | /// data that makes sense for the particular MMA operation. |
| 618 | /// The `vectorShape` matches existing NVGPU dialect op specification but |
| 619 | /// could also be flattened in the future if needed for simplification. |
| 620 | Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc, |
| 621 | OpFoldResult laneId, Value memref, |
| 622 | IndexCalculator indexFn, |
| 623 | ArrayRef<int64_t> vectorShape); |
| 624 | |
| 625 | /// Build a list of memref.store operations indexed at `(row, col)` indices |
| 626 | /// that make sense for a particular MMA instruction and specified via the |
| 627 | /// IndexCalculator callback. |
| 628 | SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc, |
| 629 | ValueRange toStore, |
| 630 | OpFoldResult laneId, Value memref, |
| 631 | const IndexCalculator &indexFn); |
| 632 | |
| 633 | /// Perform a distributed store of a vector operand of `vectorShape` for a |
| 634 | /// particular MMA instruction whose `(row, col)` indices are specified via |
| 635 | /// the IndexCalculator callback. Each `laneId` loads the subportion of the |
| 636 | /// data that makes sense for the particular MMA operation. |
| 637 | /// The `vectorShape` matches existing NVGPU dialect op specification but |
| 638 | /// could also be flattened in the future if needed for simplification. |
| 639 | SmallVector<Operation *> buildMmaSyncMemRefStoreOperand( |
| 640 | OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId, |
| 641 | Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape); |
| 642 | |
| 643 | OpBuilder &b; |
| 644 | Location loc; |
| 645 | OpFoldResult laneId; |
| 646 | }; |
| 647 | |
| 648 | //===--------------------------------------------------------------------===// |
/// Helper functions to create customizable load and store operations. The
| 650 | /// specific shapes of each MMA instruction are passed via the |
| 651 | /// IndexCalculator callback. |
| 652 | //===--------------------------------------------------------------------===// |
| 653 | |
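/// Enumerate all scalar elements of `vector` in linearized (row-major) order:
/// for each position, call `applyFn` and feed its result, together with the
/// linear index and the delinearized n-D indices, to `reduceFn`.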
| 654 | template <typename ApplyFn, typename ReduceFn> |
| 655 | static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn, |
| 656 | ReduceFn reduceFn) { |
| 657 | VectorType vectorType = cast<VectorType>(vector.getType()); |
| 658 | auto vectorShape = vectorType.getShape(); |
| 659 | auto strides = computeStrides(vectorShape); |
| 660 | for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) { |
| 661 | auto indices = delinearize(idx, strides); |
| 662 | reduceFn(applyFn(vector, idx, indices), idx, indices); |
| 663 | } |
| 664 | } |
| 665 | |
| 666 | SmallVector<Value> |
| 667 | MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc, |
| 668 | OpFoldResult laneId, Value memref, |
| 669 | const IndexCalculator &indexFn) { |
| 670 | auto aff = [&](AffineExpr e) { |
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
| 672 | }; |
| 673 | SmallVector<Value> res; |
| 674 | SmallVector<RowColIndexing> indexings = indexFn(b.getContext()); |
  for (auto indexing : indexings) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
    res.push_back(load);
| 680 | } |
| 681 | return res; |
| 682 | } |
| 683 | |
| 684 | Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand( |
| 685 | OpBuilder &b, Location loc, OpFoldResult laneId, Value memref, |
| 686 | IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) { |
  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);
  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
  foreachIndividualVectorElement(
      res,
| 694 | /*applyFn=*/ |
| 695 | [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) { |
| 696 | return loads[linearIdx]; |
| 697 | }, |
| 698 | /*reduceFn=*/ |
| 699 | [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) { |
| 700 | res = b.create<vector::InsertOp>(loc, v, res, indices); |
| 701 | }); |
| 702 | |
| 703 | return res; |
| 704 | } |
| 705 | |
| 706 | SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores( |
| 707 | OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId, |
| 708 | Value memref, const IndexCalculator &indexFn) { |
| 709 | auto aff = [&](AffineExpr e) { |
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Operation *> res;
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    Operation *store =
        b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
    res.push_back(store);
| 720 | } |
| 721 | return res; |
| 722 | } |
| 723 | |
| 724 | SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand( |
| 725 | OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId, |
| 726 | Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) { |
| 727 | SmallVector<Value> toStore; |
  toStore.reserve(32);
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref,
                           std::move(indexFn));
| 740 | } |
| 741 | |
| 742 | static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, |
| 743 | SmallVector<int64_t>> |
| 744 | makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, |
| 745 | ArrayRef<int64_t> res) { |
| 746 | SmallVector<int64_t> vlhs(lhs); |
| 747 | SmallVector<int64_t> vrhs(rhs); |
| 748 | SmallVector<int64_t> vres(res); |
  return std::make_tuple(vlhs, vrhs, vres);
| 750 | } |
| 751 | |
| 752 | FailureOr<MmaSyncBuilder::MmaSyncInfo> |
| 753 | MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape, |
| 754 | TypeRange elementalTypes) { |
| 755 | // TODO: Tablegen all this. |
| 756 | Type f16 = b.getF16Type(); |
| 757 | Type f32 = b.getF32Type(); |
  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
                       SmallVector<int64_t>{opShape},
                       /*tf32Enabled=*/true};
  }
  // This is the version with f16 accumulation.
  // TODO: version with f32 accumulation.
  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
      elementalTypes == TypeRange{f16, f16, f16}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
                       SmallVector<int64_t>{opShape},
                       /*tf32Enabled=*/false};
| 777 | } |
| 778 | return failure(); |
| 779 | } |
| 780 | |
| 781 | FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) { |
| 782 | Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get(); |
| 783 | Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get(); |
| 784 | Value resMemRef = linalgOp.getDpsInitOperand(0)->get(); |
  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
         "expected res to be a 2D memref");
| 791 | |
| 792 | int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0]; |
| 793 | int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1]; |
| 794 | int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1]; |
  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
  Type resType = getElementTypeOrSelf(resMemRef.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
| 802 | return failure(); |
| 803 | |
| 804 | MmaSyncInfo info = *maybeInfo; |
| 805 | auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns; |
| 806 | auto [lhsShape, rhsShape, resShape] = info.vectorShapes; |
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
                                   info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
| 817 | return res.getDefiningOp(); |
| 818 | } |
| 819 | |
| 820 | DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne( |
| 821 | transform::TransformRewriter &rewriter, LinalgOp linalgOp, |
| 822 | transform::ApplyToEachResultList &results, |
| 823 | transform::TransformState &state) { |
| 824 | bool fail = true; |
| 825 | // TODO: more robust detection of matmulOp, with transposes etc. |
| 826 | if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) { |
    // Do not let matmuls with extended semantics (e.g. user-defined indexing
    // maps) through this transform.
    if (linalgOp.hasUserDefinedMaps()) {
      return emitSilenceableError()
             << "only matmul ops with non-extended semantics are supported";
| 832 | } |
| 833 | Location loc = linalgOp.getLoc(); |
| 834 | // TODO: more robust computation of laneId, for now assume a single warp. |
| 835 | Value laneId = rewriter.create<gpu::ThreadIdOp>( |
| 836 | loc, rewriter.getIndexType(), gpu::Dimension::x); |
| 837 | if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp))) |
| 838 | fail = false; |
| 839 | } |
| 840 | |
| 841 | if (fail) { |
| 842 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
| 843 | << "unsupported target op: " << linalgOp; |
    diag.attachNote(linalgOp->getLoc()) << "target op";
| 845 | return diag; |
| 846 | } |
| 847 | |
| 848 | rewriter.eraseOp(linalgOp); |
| 849 | return DiagnosedSilenceableFailure::success(); |
| 850 | } |
| 851 | |
| 852 | //===----------------------------------------------------------------------===// |
| 853 | // Hopper builders. |
| 854 | //===----------------------------------------------------------------------===// |
| 855 | |
| 856 | /// Helper to create the base Hopper-specific operations that are reused in |
| 857 | /// various other places. |
| 858 | struct HopperBuilder { |
| 859 | HopperBuilder(RewriterBase &rewriter, Location loc) |
| 860 | : rewriter(rewriter), loc(loc) {} |
| 861 | |
| 862 | TypedValue<nvgpu::MBarrierGroupType> |
| 863 | buildAndInitBarrierInSharedMemory(OpFoldResult numThreads); |
| 864 | |
| 865 | /// Create tma descriptor op to initiate transfer from global to shared |
| 866 | /// memory. This must be done before the launch op, on the host. |
| 867 | TypedValue<nvgpu::TensorMapDescriptorType> |
| 868 | buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref, |
| 869 | gpu::LaunchOp launchOp); |
| 870 | |
| 871 | /// Build a tma load from global memory to shared memory using `barrier` to |
| 872 | /// synchronize. Return the number of bytes that will be transferred. |
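  /// For example (sketch), for a shared memref<64x128xf16, #shared> the
  /// returned size is 64 * 128 * 16 / 8 = 16384 bytes.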
| 873 | OpFoldResult |
| 874 | buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc, |
| 875 | TypedValue<MemRefType> sharedMemref, |
| 876 | TypedValue<nvgpu::MBarrierGroupType> barrier, |
| 877 | SmallVectorImpl<Operation *> &loadOps); |
| 878 | void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier, |
| 879 | ArrayRef<OpFoldResult> sizes); |
| 880 | |
| 881 | /// If threadIdx.x == 0 does TMA request + wait, else just wait. |
| 882 | /// Return the operation that performs the transfer on thread0. |
| 883 | // TODO: In the future, don't hardcode to thread 0 but elect a leader. |
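  //
  // Pseudo-IR sketch of the emitted structure (illustrative, not verbatim):
  //   %cond = arith.cmpi eq, %tidx, %c0
  //   scf.if %cond {
  //     nvgpu.tma.async.load ...                        // one per descriptor
  //     nvgpu.mbarrier.arrive.expect_tx %barrier, <sum of transfer sizes>
  //   } else {
  //     nvgpu.mbarrier.arrive.expect_tx %barrier, 0
  //   }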
| 884 | SmallVector<Operation *> buildPredicateLoadsOnThread0( |
| 885 | ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors, |
| 886 | ArrayRef<TypedValue<MemRefType>> sharedMemBuffers, |
| 887 | TypedValue<nvgpu::MBarrierGroupType> barrier); |
| 888 | |
| 889 | void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier); |
| 890 | |
| 891 | RewriterBase &rewriter; |
| 892 | Location loc; |
| 893 | }; |
| 894 | |
| 895 | SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0( |
| 896 | ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors, |
| 897 | ArrayRef<TypedValue<MemRefType>> sharedMemBuffers, |
| 898 | TypedValue<nvgpu::MBarrierGroupType> barrier) { |
| 899 | SmallVector<Operation *> loadOps; |
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
| 901 | Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x); |
| 902 | Value cond = |
| 903 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero); |
| 904 | // clang-format off |
| 905 | rewriter.create<scf::IfOp>( |
| 906 | /*location=*/loc, |
| 907 | /*conditional=*/cond, |
| 908 | /*thenBuilder=*/ |
| 909 | [&](OpBuilder &lb, Location loc) { |
| 910 | SmallVector<OpFoldResult> sizes; |
        sizes.reserve(globalDescriptors.size());
| 912 | for (auto [desc, shmem] : llvm::zip_equal( |
| 913 | globalDescriptors, sharedMemBuffers)) { |
| 914 | OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps); |
| 915 | sizes.push_back(sz); |
| 916 | } |
| 917 | // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load. |
| 918 | // This may or may not have perf implications. |
| 919 | buildBarrierArriveTx(barrier, sizes); |
| 920 | rewriter.create<scf::YieldOp>(loc); |
| 921 | }, |
| 922 | /*elseBuilder=*/ |
| 923 | [&](OpBuilder &lb, Location loc) { |
| 924 | // TODO: is this for no-thread divergence? |
| 925 | // Should we just yield the size and hoist? |
        buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
| 927 | rewriter.create<scf::YieldOp>(loc); |
| 928 | }); |
| 929 | // clang-format on |
| 930 | return loadOps; |
| 931 | } |
| 932 | |
| 933 | static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) { |
| 934 | return gpu::AddressSpaceAttr::get( |
| 935 | b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace()); |
| 936 | // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace)); |
| 937 | } |
| 938 | |
| 939 | TypedValue<nvgpu::MBarrierGroupType> |
| 940 | HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) { |
  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
      loc,
      nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierInitOp>(
      loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
      zero, Value());
  rewriter.create<gpu::BarrierOp>(loc);
  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
| 951 | } |
| 952 | |
| 953 | TypedValue<nvgpu::TensorMapDescriptorType> |
| 954 | HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref, |
| 955 | gpu::LaunchOp launchOp) { |
| 956 | OpBuilder::InsertionGuard guard(rewriter); |
| 957 | rewriter.setInsertionPoint(launchOp); |
| 958 | Value unrankedMemRef = rewriter.create<memref::CastOp>( |
| 959 | loc, |
| 960 | UnrankedMemRefType::get(memref.getType().getElementType(), |
| 961 | memref.getType().getMemorySpace()), |
| 962 | memref); |
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, memref);
  SmallVector<Value> sizes =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
| 969 | Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>( |
| 970 | loc, |
| 971 | nvgpu::TensorMapDescriptorType::get( |
| 972 | rewriter.getContext(), |
| 973 | MemRefType::Builder(memref.getType()) |
| 974 | .setMemorySpace(sharedMemorySpace), |
| 975 | TensorMapSwizzleKind::SWIZZLE_NONE, |
| 976 | TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO, |
| 977 | TensorMapInterleaveKind::INTERLEAVE_NONE), |
| 978 | unrankedMemRef, sizes); |
| 979 | return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc); |
| 980 | } |
| 981 | |
| 982 | OpFoldResult HopperBuilder::buildTmaAsyncLoad( |
| 983 | TypedValue<nvgpu::TensorMapDescriptorType> globalDesc, |
| 984 | TypedValue<MemRefType> sharedMemref, |
| 985 | TypedValue<nvgpu::MBarrierGroupType> barrier, |
| 986 | SmallVectorImpl<Operation *> &loadOps) { |
| 987 | MLIRContext *ctx = rewriter.getContext(); |
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
      Value(), Value());
  loadOps.push_back(loadOp);
  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr prodExprInBytes =
      computeProduct(ctx, symbols) *
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
                                                   prodExprInBytes, mixedSizes);
| 1001 | return res; |
| 1002 | } |
| 1003 | |
| 1004 | void HopperBuilder::buildBarrierArriveTx( |
| 1005 | TypedValue<nvgpu::MBarrierGroupType> barrier, |
| 1006 | ArrayRef<OpFoldResult> mixedSizes) { |
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr sumExpr = computeSum(ctx, symbols);
  OpFoldResult size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
| 1016 | rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero, |
| 1017 | Value()); |
| 1018 | } |
| 1019 | |
| 1020 | void HopperBuilder::buildTryWaitParity( |
| 1021 | TypedValue<nvgpu::MBarrierGroupType> barrier) { |
| 1022 | Type i1 = rewriter.getI1Type(); |
| 1023 | Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0); |
| 1024 | // 10M is an arbitrary, not too small or too big number to specify the number |
| 1025 | // of ticks before retry. |
| 1026 | // TODO: hoist this in a default dialect constant. |
  Value ticksBeforeRetry =
      rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
| 1030 | rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity, |
| 1031 | ticksBeforeRetry, zero); |
| 1032 | } |
| 1033 | |
| 1034 | //===----------------------------------------------------------------------===// |
| 1035 | // RewriteCopyAsTmaOp |
| 1036 | //===----------------------------------------------------------------------===// |
| 1037 | |
| 1038 | /// Helper to create the tma operations corresponding to `linalg::CopyOp`. |
| 1039 | struct CopyBuilder : public HopperBuilder { |
| 1040 | CopyBuilder(RewriterBase &rewriter, Location loc) |
| 1041 | : HopperBuilder(rewriter, loc) {} |
| 1042 | |
| 1043 | SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps); |
| 1044 | }; |
| 1045 | |
| 1046 | SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) { |
| 1047 | MLIRContext *ctx = rewriter.getContext(); |
| 1048 | if (copyOps.empty()) |
| 1049 | return SmallVector<Operation *>(); |
| 1050 | |
| 1051 | auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>(); |
  assert(launchOp && "expected launch op");
| 1053 | |
| 1054 | // 1. Init a barrier object in shared memory. |
| 1055 | OpBuilder::InsertionGuard g(rewriter); |
| 1056 | rewriter.setInsertionPoint(copyOps.front()); |
| 1057 | AffineExpr bx, by, bz; |
  bindSymbols(ctx, bx, by, bz);
  AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});
| 1064 | |
| 1065 | TypedValue<nvgpu::MBarrierGroupType> barrier = |
| 1066 | buildAndInitBarrierInSharedMemory(numThreads); |
| 1067 | |
| 1068 | SmallVector<TypedValue<MemRefType>> shmems; |
| 1069 | SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs; |
| 1070 | for (Operation *op : copyOps) { |
| 1071 | auto copyOp = cast<linalg::CopyOp>(op); |
| 1072 | auto inMemRef = |
| 1073 | cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get()); |
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");
| 1076 | |
| 1077 | // 2. Build global memory descriptor. |
| 1078 | TypedValue<nvgpu::TensorMapDescriptorType> globalDesc = |
| 1079 | buildGlobalMemRefDescriptor(inMemRef, launchOp); |
| 1080 | globalDescs.push_back(globalDesc); |
| 1081 | |
| 1082 | // 3. Shared memory and descriptor for the tmp array. |
| 1083 | auto shmem = |
| 1084 | cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get()); |
    shmems.push_back(shmem);
| 1086 | } |
| 1087 | |
| 1088 | // 4. Load in from global memory to shared memory using tma. |
| 1089 | OpBuilder::InsertionGuard g2(rewriter); |
| 1090 | rewriter.setInsertionPoint(copyOps.front()); |
| 1091 | SmallVector<Operation *> results = |
| 1092 | buildPredicateLoadsOnThread0(globalDescs, shmems, barrier); |
| 1093 | |
| 1094 | // 5. Spin-loop until data is ready. |
| 1095 | buildTryWaitParity(barrier); |
| 1096 | |
| 1097 | // 6. Erase the ops that have now been rewritten. |
| 1098 | for (Operation *op : copyOps) |
| 1099 | rewriter.eraseOp(op); |
| 1100 | |
| 1101 | return results; |
| 1102 | } |
| 1103 | |
| 1104 | DiagnosedSilenceableFailure |
| 1105 | transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter, |
| 1106 | transform::TransformResults &results, |
| 1107 | transform::TransformState &state) { |
| 1108 | auto payloadOps = state.getPayloadOps(getTarget()); |
| 1109 | gpu::LaunchOp commonLaunchOp; |
| 1110 | Operation *firstOp, *failingOp; |
| 1111 | if (llvm::any_of(payloadOps, [&](Operation *op) { |
| 1112 | if (!commonLaunchOp) { |
| 1113 | commonLaunchOp = op->getParentOfType<gpu::LaunchOp>(); |
| 1114 | firstOp = op; |
| 1115 | } |
| 1116 | auto fail = !op->getParentOfType<gpu::LaunchOp>() || |
| 1117 | commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() || |
| 1118 | !isa<linalg::CopyOp>(op); |
| 1119 | if (fail) |
| 1120 | failingOp = op; |
| 1121 | return fail; |
| 1122 | })) { |
| 1123 | DiagnosedSilenceableFailure diag = |
| 1124 | emitSilenceableError() |
| 1125 | << "target ops must be linalg::CopyOp nested under a common " |
| 1126 | "gpu.LaunchOp to be rewritten because the tma descriptors need to " |
| 1127 | "be created on the host.\nBut got: " |
| 1128 | << *firstOp << "\nand " << *failingOp; |
| 1129 | return diag; |
| 1130 | } |
| 1131 | |
| 1132 | // TODO: more robust detection of copy, with transposes etc. |
| 1133 | CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps)); |
| 1134 | |
| 1135 | return DiagnosedSilenceableFailure::success(); |
| 1136 | } |
| 1137 | |
| 1138 | //===----------------------------------------------------------------------===// |
| 1139 | // Transform op registration |
| 1140 | //===----------------------------------------------------------------------===// |
| 1141 | |
| 1142 | namespace { |
| 1143 | class NVGPUTransformDialectExtension |
| 1144 | : public transform::TransformDialectExtension< |
| 1145 | NVGPUTransformDialectExtension> { |
| 1146 | public: |
| 1147 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension) |
| 1148 | |
| 1149 | NVGPUTransformDialectExtension() { |
| 1150 | declareGeneratedDialect<arith::ArithDialect>(); |
| 1151 | declareGeneratedDialect<affine::AffineDialect>(); |
| 1152 | declareGeneratedDialect<nvgpu::NVGPUDialect>(); |
| 1153 | declareGeneratedDialect<NVVM::NVVMDialect>(); |
| 1154 | declareGeneratedDialect<vector::VectorDialect>(); |
| 1155 | registerTransformOps< |
| 1156 | #define GET_OP_LIST |
| 1157 | #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc" |
| 1158 | >(); |
| 1159 | } |
| 1160 | }; |
| 1161 | } // namespace |
| 1162 | |
| 1163 | #define GET_OP_CLASSES |
| 1164 | #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc" |
| 1165 | |
| 1166 | void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry ®istry) { |
| 1167 | registry.addExtensions<NVGPUTransformDialectExtension>(); |
| 1168 | } |
| 1169 | |