| 1 | //===- ROCDLDialect.cpp - ROCDL IR Ops and Dialect registration -----------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// This file defines the types and operation details for the ROCDL IR dialect
// in MLIR, which builds on the LLVM IR dialect. It also registers the dialect.
| 11 | // |
| 12 | // The ROCDL dialect only contains GPU specific additions on top of the general |
| 13 | // LLVM dialect. |
| 14 | // |
| 15 | //===----------------------------------------------------------------------===// |
| 16 | |
| 17 | #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" |
| 18 | |
| 19 | #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" |
| 20 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 21 | #include "mlir/IR/Builders.h" |
| 22 | #include "mlir/IR/BuiltinTypes.h" |
| 23 | #include "mlir/IR/DialectImplementation.h" |
| 24 | #include "mlir/IR/MLIRContext.h" |
| 25 | #include "mlir/IR/Operation.h" |
| 26 | #include "llvm/ADT/TypeSwitch.h" |
| 27 | #include "llvm/AsmParser/Parser.h" |
| 28 | #include "llvm/IR/Attributes.h" |
| 29 | #include "llvm/IR/Function.h" |
| 30 | #include "llvm/IR/Type.h" |
| 31 | #include "llvm/Support/SourceMgr.h" |
| 32 | |
| 33 | using namespace mlir; |
| 34 | using namespace ROCDL; |
| 35 | |
| 36 | #include "mlir/Dialect/LLVMIR/ROCDLOpsDialect.cpp.inc" |
| 37 | |
| 38 | //===----------------------------------------------------------------------===// |
| 39 | // Parsing for ROCDL ops |
| 40 | //===----------------------------------------------------------------------===// |
| 41 | |
| 42 | // <operation> ::= |
| 43 | // `llvm.amdgcn.raw.buffer.load.* %rsrc, %offset, %soffset, %aux |
| 44 | // : result_type` |
| 45 | ParseResult RawBufferLoadOp::parse(OpAsmParser &parser, |
| 46 | OperationState &result) { |
| 47 | SmallVector<OpAsmParser::UnresolvedOperand, 4> ops; |
| 48 | Type type; |
| 49 | if (parser.parseOperandList(ops, 4) || parser.parseColonType(type) || |
| 50 | parser.addTypeToList(type, result.types)) |
| 51 | return failure(); |
| 52 | |
| 53 | auto bldr = parser.getBuilder(); |
| 54 | auto int32Ty = bldr.getI32Type(); |
| 55 | auto i32x4Ty = VectorType::get({4}, int32Ty); |
| 56 | return parser.resolveOperands(ops, {i32x4Ty, int32Ty, int32Ty, int32Ty}, |
| 57 | parser.getNameLoc(), result.operands); |
| 58 | } |
| 59 | |
| 60 | void RawBufferLoadOp::print(OpAsmPrinter &p) { |
| 61 | p << " " << getOperands() << " : " << getRes().getType(); |
| 62 | } |
| 63 | |
| 64 | // <operation> ::= |
| 65 | // `llvm.amdgcn.raw.buffer.store.* %vdata, %rsrc, %offset, |
| 66 | // %soffset, %aux : result_type` |
| 67 | ParseResult RawBufferStoreOp::parse(OpAsmParser &parser, |
| 68 | OperationState &result) { |
| 69 | SmallVector<OpAsmParser::UnresolvedOperand, 5> ops; |
| 70 | Type type; |
| 71 | if (parser.parseOperandList(ops, 5) || parser.parseColonType(type)) |
| 72 | return failure(); |
| 73 | |
| 74 | auto bldr = parser.getBuilder(); |
| 75 | auto int32Ty = bldr.getI32Type(); |
| 76 | auto i32x4Ty = VectorType::get({4}, int32Ty); |
| 77 | |
| 78 | if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty}, |
| 79 | parser.getNameLoc(), result.operands)) |
| 80 | return failure(); |
| 81 | return success(); |
| 82 | } |
| 83 | |
| 84 | void RawBufferStoreOp::print(OpAsmPrinter &p) { |
| 85 | p << " " << getOperands() << " : " << getVdata().getType(); |
| 86 | } |
| 87 | |
| 88 | // <operation> ::= |
| 89 | // `llvm.amdgcn.raw.buffer.atomic.fadd.* %vdata, %rsrc, %offset, |
| 90 | // %soffset, %aux : result_type` |
| 91 | ParseResult RawBufferAtomicFAddOp::parse(OpAsmParser &parser, |
| 92 | OperationState &result) { |
| 93 | SmallVector<OpAsmParser::UnresolvedOperand, 5> ops; |
| 94 | Type type; |
| 95 | if (parser.parseOperandList(ops, 5) || parser.parseColonType(type)) |
| 96 | return failure(); |
| 97 | |
| 98 | auto bldr = parser.getBuilder(); |
| 99 | auto int32Ty = bldr.getI32Type(); |
| 100 | auto i32x4Ty = VectorType::get({4}, int32Ty); |
| 101 | |
| 102 | if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty}, |
| 103 | parser.getNameLoc(), result.operands)) |
| 104 | return failure(); |
| 105 | return success(); |
| 106 | } |
| 107 | |
| 108 | void RawBufferAtomicFAddOp::print(mlir::OpAsmPrinter &p) { |
| 109 | p << " " << getOperands() << " : " << getVdata().getType(); |
| 110 | } |
| 111 | |
| 112 | // <operation> ::= |
| 113 | // `llvm.amdgcn.raw.buffer.atomic.fmax.* %vdata, %rsrc, %offset, |
| 114 | // %soffset, %aux : result_type` |
| 115 | ParseResult RawBufferAtomicFMaxOp::parse(OpAsmParser &parser, |
| 116 | OperationState &result) { |
| 117 | SmallVector<OpAsmParser::UnresolvedOperand, 5> ops; |
| 118 | Type type; |
| 119 | if (parser.parseOperandList(ops, 5) || parser.parseColonType(type)) |
| 120 | return failure(); |
| 121 | |
| 122 | auto bldr = parser.getBuilder(); |
| 123 | auto int32Ty = bldr.getI32Type(); |
| 124 | auto i32x4Ty = VectorType::get({4}, int32Ty); |
| 125 | |
| 126 | if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty}, |
| 127 | parser.getNameLoc(), result.operands)) |
| 128 | return failure(); |
| 129 | return success(); |
| 130 | } |
| 131 | |
| 132 | void RawBufferAtomicFMaxOp::print(mlir::OpAsmPrinter &p) { |
| 133 | p << " " << getOperands() << " : " << getVdata().getType(); |
| 134 | } |
| 135 | |
| 136 | // <operation> ::= |
| 137 | // `llvm.amdgcn.raw.buffer.atomic.smax.* %vdata, %rsrc, %offset, |
| 138 | // %soffset, %aux : result_type` |
| 139 | ParseResult RawBufferAtomicSMaxOp::parse(OpAsmParser &parser, |
| 140 | OperationState &result) { |
| 141 | SmallVector<OpAsmParser::UnresolvedOperand, 5> ops; |
| 142 | Type type; |
| 143 | if (parser.parseOperandList(ops, 5) || parser.parseColonType(type)) |
| 144 | return failure(); |
| 145 | |
| 146 | auto bldr = parser.getBuilder(); |
| 147 | auto int32Ty = bldr.getI32Type(); |
| 148 | auto i32x4Ty = VectorType::get({4}, int32Ty); |
| 149 | |
| 150 | if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty}, |
| 151 | parser.getNameLoc(), result.operands)) |
| 152 | return failure(); |
| 153 | return success(); |
| 154 | } |
| 155 | |
| 156 | void RawBufferAtomicSMaxOp::print(mlir::OpAsmPrinter &p) { |
| 157 | p << " " << getOperands() << " : " << getVdata().getType(); |
| 158 | } |
| 159 | |
| 160 | // <operation> ::= |
| 161 | // `llvm.amdgcn.raw.buffer.atomic.umin.* %vdata, %rsrc, %offset, |
| 162 | // %soffset, %aux : result_type` |
| 163 | ParseResult RawBufferAtomicUMinOp::parse(OpAsmParser &parser, |
| 164 | OperationState &result) { |
| 165 | SmallVector<OpAsmParser::UnresolvedOperand, 5> ops; |
| 166 | Type type; |
| 167 | if (parser.parseOperandList(ops, 5) || parser.parseColonType(type)) |
| 168 | return failure(); |
| 169 | |
| 170 | auto bldr = parser.getBuilder(); |
| 171 | auto int32Ty = bldr.getI32Type(); |
| 172 | auto i32x4Ty = VectorType::get({4}, int32Ty); |
| 173 | |
| 174 | if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty}, |
| 175 | parser.getNameLoc(), result.operands)) |
| 176 | return failure(); |
| 177 | return success(); |
| 178 | } |
| 179 | |
| 180 | void RawBufferAtomicUMinOp::print(mlir::OpAsmPrinter &p) { |
| 181 | p << " " << getOperands() << " : " << getVdata().getType(); |
| 182 | } |
| 183 | |
| 184 | //===----------------------------------------------------------------------===// |
| 185 | // ROCDLDialect initialization, type parsing, and registration. |
| 186 | //===----------------------------------------------------------------------===// |
| 187 | |
// TODO: This should be the llvm.rocdl dialect once this is supported.
//
// Registers the TableGen-generated ROCDL operations and attributes with the
// dialect, enables unknown operations, and declares that ROCDLTargetAttr
// promises an implementation of gpu::TargetAttrInterface.
void ROCDLDialect::initialize() {
  // The operation list is expanded from the generated ROCDLOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
      >();

  // The attribute list is expanded from the generated
  // ROCDLOpsAttributes.cpp.inc.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/ROCDLOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all ROCDL operations are registered.
  allowUnknownOperations();
  // The interface implementation is provided elsewhere; this only records the
  // promise that it will be registered.
  declarePromisedInterface<gpu::TargetAttrInterface, ROCDLTargetAttr>();
}
| 204 | |
| 205 | LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op, |
| 206 | NamedAttribute attr) { |
| 207 | // Kernel function attribute should be attached to functions. |
| 208 | if (kernelAttrName.getName() == attr.getName()) { |
| 209 | if (!isa<LLVM::LLVMFuncOp>(op)) { |
| 210 | return op->emitError() << "'" << kernelAttrName.getName() |
| 211 | << "' attribute attached to unexpected op" ; |
| 212 | } |
| 213 | } |
| 214 | return success(); |
| 215 | } |
| 216 | |
| 217 | //===----------------------------------------------------------------------===// |
| 218 | // ROCDL target attribute. |
| 219 | //===----------------------------------------------------------------------===// |
| 220 | LogicalResult |
| 221 | ROCDLTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
| 222 | int optLevel, StringRef triple, StringRef chip, |
| 223 | StringRef features, StringRef abiVersion, |
| 224 | DictionaryAttr flags, ArrayAttr files) { |
| 225 | if (optLevel < 0 || optLevel > 3) { |
| 226 | emitError() << "The optimization level must be a number between 0 and 3." ; |
| 227 | return failure(); |
| 228 | } |
| 229 | if (triple.empty()) { |
| 230 | emitError() << "The target triple cannot be empty." ; |
| 231 | return failure(); |
| 232 | } |
| 233 | if (chip.empty()) { |
| 234 | emitError() << "The target chip cannot be empty." ; |
| 235 | return failure(); |
| 236 | } |
| 237 | if (abiVersion != "400" && abiVersion != "500" && abiVersion != "600" ) { |
| 238 | emitError() << "Invalid ABI version, it must be `400`, `500` or '600'." ; |
| 239 | return failure(); |
| 240 | } |
| 241 | if (files && !llvm::all_of(files, [](::mlir::Attribute attr) { |
| 242 | return attr && mlir::isa<StringAttr>(attr); |
| 243 | })) { |
| 244 | emitError() << "All the elements in the `link` array must be strings." ; |
| 245 | return failure(); |
| 246 | } |
| 247 | return success(); |
| 248 | } |
| 249 | |
| 250 | #define GET_OP_CLASSES |
| 251 | #include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc" |
| 252 | |
| 253 | #define GET_ATTRDEF_CLASSES |
| 254 | #include "mlir/Dialect/LLVMIR/ROCDLOpsAttributes.cpp.inc" |
| 255 | |