1//===- EmulateAtomics.cpp - Emulate unsupported AMDGPU atomics ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
10
11#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
15#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16#include "mlir/Dialect/Vector/IR/VectorOps.h"
17#include "mlir/IR/BuiltinAttributes.h"
18#include "mlir/IR/TypeUtilities.h"
19#include "mlir/Transforms/DialectConversion.h"
20
21namespace mlir::amdgpu {
22#define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
23#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
24} // namespace mlir::amdgpu
25
26using namespace mlir;
27using namespace mlir::amdgpu;
28
29namespace {
30struct AmdgpuEmulateAtomicsPass
31 : public amdgpu::impl::AmdgpuEmulateAtomicsPassBase<
32 AmdgpuEmulateAtomicsPass> {
33 using AmdgpuEmulateAtomicsPassBase<
34 AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;
35 void runOnOperation() override;
36};
37
38template <typename AtomicOp, typename ArithOp>
39struct RawBufferAtomicByCasPattern : public OpConversionPattern<AtomicOp> {
40 using OpConversionPattern<AtomicOp>::OpConversionPattern;
41 using Adaptor = typename AtomicOp::Adaptor;
42
43 LogicalResult
44 matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
45 ConversionPatternRewriter &rewriter) const override;
46};
47} // namespace
48
49namespace {
/// Selects how patchOperandSegmentSizes() adjusts the leading entry of an
/// "operandSegmentSizes" attribute.
enum class DataArgAction : unsigned char {
  /// Repeat the first segment-size entry, so the list gains one entry.
  Duplicate,
  /// Remove the first segment-size entry, so the list loses one entry.
  Drop,
};
54} // namespace
55
56// Fix up the fact that, when we're migrating from a general bugffer atomic
57// to a load or to a CAS, the number of openrands, and thus the number of
58// entries needed in operandSegmentSizes, needs to change. We use this method
59// because we'd like to preserve unknown attributes on the atomic instead of
60// discarding them.
61static void patchOperandSegmentSizes(ArrayRef<NamedAttribute> attrs,
62 SmallVectorImpl<NamedAttribute> &newAttrs,
63 DataArgAction action) {
64 newAttrs.reserve(N: attrs.size());
65 for (NamedAttribute attr : attrs) {
66 if (attr.getName().getValue() != "operandSegmentSizes") {
67 newAttrs.push_back(Elt: attr);
68 continue;
69 }
70 auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
71 MLIRContext *context = segmentAttr.getContext();
72 DenseI32ArrayAttr newSegments;
73 switch (action) {
74 case DataArgAction::Drop:
75 newSegments = DenseI32ArrayAttr::get(
76 context, segmentAttr.asArrayRef().drop_front());
77 break;
78 case DataArgAction::Duplicate: {
79 SmallVector<int32_t> newVals;
80 ArrayRef<int32_t> oldVals = segmentAttr.asArrayRef();
81 newVals.push_back(Elt: oldVals[0]);
82 newVals.append(in_start: oldVals.begin(), in_end: oldVals.end());
83 newSegments = DenseI32ArrayAttr::get(context, newVals);
84 break;
85 }
86 }
87 newAttrs.push_back(Elt: NamedAttribute(attr.getName(), newSegments));
88 }
89}
90
91// A helper function to flatten a vector value to a scalar containing its bits,
92// returning the value itself if othetwise.
93static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
94 Value val) {
95 auto vectorType = dyn_cast<VectorType>(val.getType());
96 if (!vectorType)
97 return val;
98
99 int64_t bitwidth =
100 vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
101 Type allBitsType = rewriter.getIntegerType(bitwidth);
102 auto allBitsVecType = VectorType::get({1}, allBitsType);
103 Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val);
104 Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0);
105 return scalar;
106}
107
108template <typename AtomicOp, typename ArithOp>
109LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
110 AtomicOp atomicOp, Adaptor adaptor,
111 ConversionPatternRewriter &rewriter) const {
112 Location loc = atomicOp.getLoc();
113
114 ArrayRef<NamedAttribute> origAttrs = atomicOp->getAttrs();
115 ValueRange operands = adaptor.getOperands();
116 Value data = operands.take_front()[0];
117 ValueRange invariantArgs = operands.drop_front();
118 Type dataType = data.getType();
119
120 SmallVector<NamedAttribute> loadAttrs;
121 patchOperandSegmentSizes(attrs: origAttrs, newAttrs&: loadAttrs, action: DataArgAction::Drop);
122 Value initialLoad =
123 rewriter.create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs);
124 Block *currentBlock = rewriter.getInsertionBlock();
125 Block *afterAtomic =
126 rewriter.splitBlock(block: currentBlock, before: rewriter.getInsertionPoint());
127 Block *loopBlock = rewriter.createBlock(insertBefore: afterAtomic, argTypes: {dataType}, locs: {loc});
128
129 rewriter.setInsertionPointToEnd(currentBlock);
130 rewriter.create<cf::BranchOp>(loc, loopBlock, initialLoad);
131
132 rewriter.setInsertionPointToEnd(loopBlock);
133 Value prevLoad = loopBlock->getArgument(i: 0);
134 Value operated = rewriter.create<ArithOp>(loc, data, prevLoad);
135 dataType = operated.getType();
136
137 SmallVector<NamedAttribute> cmpswapAttrs;
138 patchOperandSegmentSizes(attrs: origAttrs, newAttrs&: cmpswapAttrs, action: DataArgAction::Duplicate);
139 SmallVector<Value> cmpswapArgs = {operated, prevLoad};
140 cmpswapArgs.append(in_start: invariantArgs.begin(), in_end: invariantArgs.end());
141 Value atomicRes = rewriter.create<RawBufferAtomicCmpswapOp>(
142 loc, dataType, cmpswapArgs, cmpswapAttrs);
143
144 // We care about exact bitwise equality here, so do some bitcasts.
145 // These will fold away during lowering to the ROCDL dialect, where
146 // an int->float bitcast is introduced to account for the fact that cmpswap
147 // only takes integer arguments.
148
149 Value prevLoadForCompare = flattenVecToBits(rewriter, loc, val: prevLoad);
150 Value atomicResForCompare = flattenVecToBits(rewriter, loc, val: atomicRes);
151 if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
152 Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
153 prevLoadForCompare =
154 rewriter.create<arith::BitcastOp>(loc, equivInt, prevLoad);
155 atomicResForCompare =
156 rewriter.create<arith::BitcastOp>(loc, equivInt, atomicRes);
157 }
158 Value canLeave = rewriter.create<arith::CmpIOp>(
159 loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);
160 rewriter.create<cf::CondBranchOp>(loc, canLeave, afterAtomic, ValueRange{},
161 loopBlock, atomicRes);
162 rewriter.eraseOp(op: atomicOp);
163 return success();
164}
165
166void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
167 ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) {
168 // gfx10 has no atomic adds.
169 if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) {
170 target.addIllegalOp<RawBufferAtomicFaddOp>();
171 }
172 // gfx11 has no fp16 atomics
173 if (chipset.majorVersion == 11) {
174 target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
175 [](RawBufferAtomicFaddOp op) -> bool {
176 Type elemType = getElementTypeOrSelf(op.getValue().getType());
177 return !isa<Float16Type, BFloat16Type>(elemType);
178 });
179 }
180 // gfx9 has no to a very limited support for floating-point min and max.
181 if (chipset.majorVersion == 9) {
182 if (chipset >= Chipset(9, 0, 0xa)) {
183 // gfx90a supports f64 max (and min, but we don't have a min wrapper right
184 // now) but all other types need to be emulated.
185 target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
186 [](RawBufferAtomicFmaxOp op) -> bool {
187 return op.getValue().getType().isF64();
188 });
189 } else {
190 target.addIllegalOp<RawBufferAtomicFmaxOp>();
191 }
192 // TODO(https://github.com/llvm/llvm-project/issues/129206): Refactor
193 // this to avoid hardcoding ISA version: gfx950 has bf16 atomics.
194 if (chipset < Chipset(9, 5, 0)) {
195 target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
196 [](RawBufferAtomicFaddOp op) -> bool {
197 Type elemType = getElementTypeOrSelf(op.getValue().getType());
198 return !isa<BFloat16Type>(elemType);
199 });
200 }
201 }
202 patterns.add<
203 RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
204 RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
205 RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
206 RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
207 patterns.getContext());
208}
209
210void AmdgpuEmulateAtomicsPass::runOnOperation() {
211 Operation *op = getOperation();
212 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
213 if (failed(Result: maybeChipset)) {
214 emitError(op->getLoc(), "Invalid chipset name: " + chipset);
215 return signalPassFailure();
216 }
217
218 MLIRContext &ctx = getContext();
219 ConversionTarget target(ctx);
220 RewritePatternSet patterns(&ctx);
221 target.markUnknownOpDynamicallyLegal(
222 fn: [](Operation *op) -> bool { return true; });
223
224 populateAmdgpuEmulateAtomicsPatterns(target, patterns, chipset: *maybeChipset);
225 if (failed(applyPartialConversion(op, target, std::move(patterns))))
226 return signalPassFailure();
227}
228

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp