1//===- EmulateAtomics.cpp - Emulate unsupported AMDGPU atomics ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
10
11#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
15#include "mlir/Dialect/Vector/IR/VectorOps.h"
16#include "mlir/IR/BuiltinAttributes.h"
17#include "mlir/IR/TypeUtilities.h"
18#include "mlir/Transforms/DialectConversion.h"
19
20namespace mlir::amdgpu {
21#define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
22#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
23} // namespace mlir::amdgpu
24
25using namespace mlir;
26using namespace mlir::amdgpu;
27
28namespace {
// The emulation pass: inherits the tablegen-generated base (including the
// `chipset` option) and supplies runOnOperation(), defined below.
struct AmdgpuEmulateAtomicsPass
    : public amdgpu::impl::AmdgpuEmulateAtomicsPassBase<
          AmdgpuEmulateAtomicsPass> {
  using AmdgpuEmulateAtomicsPassBase<
      AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;
  void runOnOperation() override;
};
36
// Conversion pattern that replaces a raw-buffer atomic op (AtomicOp) whose
// combining operation is ArithOp with a load + compare-and-swap retry loop,
// for chipsets that lack the native atomic. The rewrite itself is defined
// out-of-line in matchAndRewrite below.
template <typename AtomicOp, typename ArithOp>
struct RawBufferAtomicByCasPattern : public OpConversionPattern<AtomicOp> {
  using OpConversionPattern<AtomicOp>::OpConversionPattern;
  using Adaptor = typename AtomicOp::Adaptor;

  LogicalResult
  matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
46} // namespace
47
namespace {
// How patchOperandSegmentSizes() should adjust the entry for the atomic's
// data operand when recomputing operandSegmentSizes.
enum class DataArgAction : unsigned char {
  Duplicate, // Data operand appears twice (cmpswap takes source + compare).
  Drop,      // Data operand is removed entirely (plain load takes none).
};
} // namespace
54
55// Fix up the fact that, when we're migrating from a general bugffer atomic
56// to a load or to a CAS, the number of openrands, and thus the number of
57// entries needed in operandSegmentSizes, needs to change. We use this method
58// because we'd like to preserve unknown attributes on the atomic instead of
59// discarding them.
60static void patchOperandSegmentSizes(ArrayRef<NamedAttribute> attrs,
61 SmallVectorImpl<NamedAttribute> &newAttrs,
62 DataArgAction action) {
63 newAttrs.reserve(N: attrs.size());
64 for (NamedAttribute attr : attrs) {
65 if (attr.getName().getValue() != "operandSegmentSizes") {
66 newAttrs.push_back(Elt: attr);
67 continue;
68 }
69 auto segmentAttr = cast<DenseI32ArrayAttr>(Val: attr.getValue());
70 MLIRContext *context = segmentAttr.getContext();
71 DenseI32ArrayAttr newSegments;
72 switch (action) {
73 case DataArgAction::Drop:
74 newSegments = DenseI32ArrayAttr::get(
75 context, content: segmentAttr.asArrayRef().drop_front());
76 break;
77 case DataArgAction::Duplicate: {
78 SmallVector<int32_t> newVals;
79 ArrayRef<int32_t> oldVals = segmentAttr.asArrayRef();
80 newVals.push_back(Elt: oldVals[0]);
81 newVals.append(in_start: oldVals.begin(), in_end: oldVals.end());
82 newSegments = DenseI32ArrayAttr::get(context, content: newVals);
83 break;
84 }
85 }
86 newAttrs.push_back(Elt: NamedAttribute(attr.getName(), newSegments));
87 }
88}
89
90// A helper function to flatten a vector value to a scalar containing its bits,
91// returning the value itself if othetwise.
92static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
93 Value val) {
94 auto vectorType = dyn_cast<VectorType>(Val: val.getType());
95 if (!vectorType)
96 return val;
97
98 int64_t bitwidth =
99 vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
100 Type allBitsType = rewriter.getIntegerType(width: bitwidth);
101 auto allBitsVecType = VectorType::get(shape: {1}, elementType: allBitsType);
102 Value bitcast = rewriter.create<vector::BitCastOp>(location: loc, args&: allBitsVecType, args&: val);
103 Value scalar = rewriter.create<vector::ExtractOp>(location: loc, args&: bitcast, args: 0);
104 return scalar;
105}
106
// Emulate `atomicOp` with a classic CAS retry loop:
//   entry:      %init = load ...; br loop(%init)
//   loop(%prev): %new = ArithOp(%data, %prev)
//                %seen = cmpswap(%new, %prev, ...)
//                br (%seen == %prev) ? after : loop(%seen)
template <typename AtomicOp, typename ArithOp>
LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
    AtomicOp atomicOp, Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = atomicOp.getLoc();

  // The first operand is the value being combined; the remaining operands
  // (buffer descriptor, indices, etc.) are reused as-is by the emitted load
  // and cmpswap.
  ArrayRef<NamedAttribute> origAttrs = atomicOp->getAttrs();
  ValueRange operands = adaptor.getOperands();
  Value data = operands.take_front()[0];
  ValueRange invariantArgs = operands.drop_front();
  Type dataType = data.getType();

  // Seed the loop with a plain load of the current memory contents. The
  // load has no data operand, so its operandSegmentSizes drops one entry.
  SmallVector<NamedAttribute> loadAttrs;
  patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop);
  Value initialLoad =
      rewriter.create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs);
  // Carve the CFG into: currentBlock -> loopBlock(arg) -> afterAtomic.
  Block *currentBlock = rewriter.getInsertionBlock();
  Block *afterAtomic =
      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
  Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});

  rewriter.setInsertionPointToEnd(currentBlock);
  rewriter.create<cf::BranchOp>(loc, loopBlock, initialLoad);

  rewriter.setInsertionPointToEnd(loopBlock);
  Value prevLoad = loopBlock->getArgument(0);
  // Compute the value the atomic would have produced...
  Value operated = rewriter.create<ArithOp>(loc, data, prevLoad);
  dataType = operated.getType();

  // ...and attempt to commit it. The cmpswap takes the data operand twice
  // (source and comparison), hence the duplicated segment entry.
  SmallVector<NamedAttribute> cmpswapAttrs;
  patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
  SmallVector<Value> cmpswapArgs = {operated, prevLoad};
  cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
  Value atomicRes = rewriter.create<RawBufferAtomicCmpswapOp>(
      loc, dataType, cmpswapArgs, cmpswapAttrs);

  // We care about exact bitwise equality here, so do some bitcasts.
  // These will fold away during lowering to the ROCDL dialect, where
  // an int->float bitcast is introduced to account for the fact that cmpswap
  // only takes integer arguments.

  Value prevLoadForCompare = flattenVecToBits(rewriter, loc, prevLoad);
  Value atomicResForCompare = flattenVecToBits(rewriter, loc, atomicRes);
  if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
    // Scalar float case (vectors were already flattened above): compare the
    // raw bits via a same-width integer.
    Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
    prevLoadForCompare =
        rewriter.create<arith::BitcastOp>(loc, equivInt, prevLoad);
    atomicResForCompare =
        rewriter.create<arith::BitcastOp>(loc, equivInt, atomicRes);
  }
  // The swap succeeded iff cmpswap observed exactly the value we loaded
  // last iteration; otherwise retry with the freshly observed value.
  Value canLeave = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);
  rewriter.create<cf::CondBranchOp>(loc, canLeave, afterAtomic, ValueRange{},
                                    loopBlock, atomicRes);
  rewriter.eraseOp(atomicOp);
  return success();
}
164
// Mark the atomics that `chipset` cannot execute natively as illegal and
// register the CAS-loop emulation patterns that legalize them.
void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
    ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset,
    PatternBenefit benefit) {
  // gfx10 and gfx9 chips before gfx908 have no buffer atomic fadd at all.
  if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) {
    target.addIllegalOp<RawBufferAtomicFaddOp>();
  }
  // gfx11 has no fp16/bf16 atomic fadd; other element types are fine.
  if (chipset.majorVersion == 11) {
    target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
        [](RawBufferAtomicFaddOp op) -> bool {
          Type elemType = getElementTypeOrSelf(op.getValue().getType());
          return !isa<Float16Type, BFloat16Type>(elemType);
        });
  }
  // gfx9 has either no support or very limited support for floating-point
  // min and max.
  if (chipset.majorVersion == 9) {
    if (chipset >= Chipset(9, 0, 0xa)) {
      // gfx90a supports f64 max (and min, but we don't have a min wrapper right
      // now) but all other types need to be emulated.
      target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
          [](RawBufferAtomicFmaxOp op) -> bool {
            return op.getValue().getType().isF64();
          });
    } else {
      target.addIllegalOp<RawBufferAtomicFmaxOp>();
    }
    // TODO(https://github.com/llvm/llvm-project/issues/129206): Refactor
    // this to avoid hardcoding ISA version: gfx950 has bf16 atomics.
    if (chipset < Chipset(9, 5, 0)) {
      // Pre-gfx950 gfx9 lacks bf16 atomic fadd.
      target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
          [](RawBufferAtomicFaddOp op) -> bool {
            Type elemType = getElementTypeOrSelf(op.getValue().getType());
            return !isa<BFloat16Type>(elemType);
          });
    }
  }
  patterns.add<
      RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
      RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
      RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
      RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
      patterns.getContext(), benefit);
}
209
210void AmdgpuEmulateAtomicsPass::runOnOperation() {
211 Operation *op = getOperation();
212 FailureOr<Chipset> maybeChipset = Chipset::parse(name: chipset);
213 if (failed(Result: maybeChipset)) {
214 emitError(loc: op->getLoc(), message: "Invalid chipset name: " + chipset);
215 return signalPassFailure();
216 }
217
218 MLIRContext &ctx = getContext();
219 ConversionTarget target(ctx);
220 RewritePatternSet patterns(&ctx);
221 target.markUnknownOpDynamicallyLegal(
222 fn: [](Operation *op) -> bool { return true; });
223
224 populateAmdgpuEmulateAtomicsPatterns(target, patterns, chipset: *maybeChipset);
225 if (failed(Result: applyPartialConversion(op, target, patterns: std::move(patterns))))
226 return signalPassFailure();
227}
228

source code of mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp