//===-- CUFOps.cpp --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

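/// Wrap the allocated type in a !fir.ref for the op result. A null type is
/// returned when the requested type is already a reference type.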
static mlir::Type wrapAllocaResultType(mlir::Type intype) {
  if (mlir::isa<fir::ReferenceType>(intype))
    return {};
  return fir::ReferenceType::get(intype);
}

void cuf::AllocOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
                         mlir::Type inType, llvm::StringRef uniqName,
                         llvm::StringRef bindcName,
                         cuf::DataAttributeAttr cudaAttr,
                         mlir::ValueRange typeparams, mlir::ValueRange shape,
                         llvm::ArrayRef<mlir::NamedAttribute> attributes) {
  mlir::StringAttr nameAttr =
      uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName);
  mlir::StringAttr bindcAttr =
      bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName);
  build(builder, result, wrapAllocaResultType(inType),
        mlir::TypeAttr::get(inType), nameAttr, bindcAttr, typeparams, shape,
        cudaAttr);
  result.addAttributes(attributes);
}

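/// Check that the op carries one of the CUDA data attributes accepted for
/// allocation and free: device, managed, unified, pinned or shared.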
template <typename Op>
static llvm::LogicalResult checkCudaAttr(Op op) {
  if (op.getDataAttr() == cuf::DataAttribute::Device ||
      op.getDataAttr() == cuf::DataAttribute::Managed ||
      op.getDataAttr() == cuf::DataAttribute::Unified ||
      op.getDataAttr() == cuf::DataAttribute::Pinned ||
      op.getDataAttr() == cuf::DataAttribute::Shared)
    return mlir::success();
  return op.emitOpError()
         << "expect device, managed, pinned, unified or shared cuda attribute";
}

llvm::LogicalResult cuf::AllocOp::verify() { return checkCudaAttr(*this); }

//===----------------------------------------------------------------------===//
// FreeOp
//===----------------------------------------------------------------------===//

llvm::LogicalResult cuf::FreeOp::verify() { return checkCudaAttr(*this); }

//===----------------------------------------------------------------------===//
// AllocateOp
//===----------------------------------------------------------------------===//

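/// When a stream operand is present and passed by reference, it must be a
/// reference to an i64 value.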
template <typename OpTy>
static llvm::LogicalResult checkStreamType(OpTy op) {
  if (!op.getStream())
    return mlir::success();
  if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getStream().getType()))
    if (!refTy.getEleTy().isInteger(64))
      return op.emitOpError("stream is expected to be an i64 reference");
  return mlir::success();
}

llvm::LogicalResult cuf::AllocateOp::verify() {
  if (getPinned() && getStream())
    return emitOpError("pinned and stream cannot appear at the same time");
  if (!mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getBox().getType())))
    return emitOpError(
        "expect box to be a reference to a class or box type value");
  if (getSource() &&
      !mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getSource().getType())))
    return emitOpError(
        "expect source to be a class or box type value or a reference to one");
  if (getErrmsg() &&
      !mlir::isa<fir::BoxType>(fir::unwrapRefType(getErrmsg().getType())))
    return emitOpError(
        "expect errmsg to be a box type value or a reference to one");
  if (getErrmsg() && !getHasStat())
    return emitOpError("expect stat attribute when errmsg is provided");
  return checkStreamType(*this);
}

//===----------------------------------------------------------------------===//
// DataTransferOp
//===----------------------------------------------------------------------===//

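/// A data transfer is valid when source and destination are references or
/// descriptors (in any combination), or when the source is a constant scalar.
/// A shape may only be provided when at least one side is a reference.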
llvm::LogicalResult cuf::DataTransferOp::verify() {
  mlir::Type srcTy = getSrc().getType();
  mlir::Type dstTy = getDst().getType();
  if (getShape()) {
    if (!fir::isa_ref_type(srcTy) && !fir::isa_ref_type(dstTy))
      return emitOpError()
             << "shape can only be specified on data transfer with references";
  }
  if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) ||
      (fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)) ||
      (fir::isa_ref_type(srcTy) && fir::isa_box_type(dstTy)) ||
      (fir::isa_box_type(srcTy) && fir::isa_ref_type(dstTy)))
    return mlir::success();
  if (fir::isa_trivial(srcTy) &&
      matchPattern(getSrc().getDefiningOp(), mlir::m_Constant()))
    return mlir::success();

  return emitOpError()
         << "expect src and dst to be references or descriptors or src to "
            "be a constant: "
         << srcTy << " - " << dstTy;
}

//===----------------------------------------------------------------------===//
// DeallocateOp
//===----------------------------------------------------------------------===//

llvm::LogicalResult cuf::DeallocateOp::verify() {
  if (!mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getBox().getType())))
    return emitOpError(
        "expect box to be a reference to a class or box type value");
  if (getErrmsg() &&
      !mlir::isa<fir::BoxType>(fir::unwrapRefType(getErrmsg().getType())))
    return emitOpError(
        "expect errmsg to be a box type value or a reference to one");
  if (getErrmsg() && !getHasStat())
    return emitOpError("expect stat attribute when errmsg is provided");
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// KernelLaunchOp
//===----------------------------------------------------------------------===//

llvm::LogicalResult cuf::KernelLaunchOp::verify() {
  return checkStreamType(*this);
}

//===----------------------------------------------------------------------===//
// KernelOp
//===----------------------------------------------------------------------===//

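/// The single region of cuf.kernel holds the loop nest, so it is the op's
/// only loop region.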
llvm::SmallVector<mlir::Region *> cuf::KernelOp::getLoopRegions() {
  return {&getRegion()};
}

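/// Parse the launch configuration values of cuf.kernel: either `*`, a single
/// value, or a parenthesized comma-separated list of values. All values are
/// parsed with the i32 type.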
mlir::ParseResult parseCUFKernelValues(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &values,
    llvm::SmallVectorImpl<mlir::Type> &types) {
  if (mlir::succeeded(parser.parseOptionalStar()))
    return mlir::success();

  if (mlir::succeeded(parser.parseOptionalLParen())) {
    if (mlir::failed(parser.parseCommaSeparatedList(
            mlir::AsmParser::Delimiter::None, [&]() {
              if (parser.parseOperand(values.emplace_back()))
                return mlir::failure();
              return mlir::success();
            })))
      return mlir::failure();
    auto builder = parser.getBuilder();
    for (size_t i = 0; i < values.size(); i++) {
      types.emplace_back(builder.getI32Type());
    }
    if (parser.parseRParen())
      return mlir::failure();
  } else {
    if (parser.parseOperand(values.emplace_back()))
      return mlir::failure();
    auto builder = parser.getBuilder();
    types.emplace_back(builder.getI32Type());
    return mlir::success();
  }
  return mlir::success();
}

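/// Print the launch configuration values: `*` when the list is empty, a bare
/// value for a single entry, and a parenthesized list otherwise.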
void printCUFKernelValues(mlir::OpAsmPrinter &p, mlir::Operation *op,
                          mlir::ValueRange values, mlir::TypeRange types) {
  if (values.empty())
    p << "*";

  if (values.size() > 1)
    p << "(";
  llvm::interleaveComma(values, p, [&p](mlir::Value v) { p << v; });
  if (values.size() > 1)
    p << ")";
}

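/// Parse the loop control of cuf.kernel: the parenthesized induction
/// variables with their types, the `= (...) to (...) step (...)` bound lists
/// with their types, and finally the body region using the induction
/// variables as entry block arguments.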
mlir::ParseResult parseCUFKernelLoopControl(
    mlir::OpAsmParser &parser, mlir::Region &region,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &lowerbound,
    llvm::SmallVectorImpl<mlir::Type> &lowerboundType,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &upperbound,
    llvm::SmallVectorImpl<mlir::Type> &upperboundType,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &step,
    llvm::SmallVectorImpl<mlir::Type> &stepType) {

  llvm::SmallVector<mlir::OpAsmParser::Argument> inductionVars;
  if (parser.parseLParen() ||
      parser.parseArgumentList(inductionVars,
                               mlir::OpAsmParser::Delimiter::None,
                               /*allowType=*/true) ||
      parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
      parser.parseOperandList(lowerbound, inductionVars.size(),
                              mlir::OpAsmParser::Delimiter::None) ||
      parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
      parser.parseKeyword("to") || parser.parseLParen() ||
      parser.parseOperandList(upperbound, inductionVars.size(),
                              mlir::OpAsmParser::Delimiter::None) ||
      parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
      parser.parseKeyword("step") || parser.parseLParen() ||
      parser.parseOperandList(step, inductionVars.size(),
                              mlir::OpAsmParser::Delimiter::None) ||
      parser.parseColonTypeList(stepType) || parser.parseRParen())
    return mlir::failure();
  return parser.parseRegion(region, inductionVars);
}

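/// Print the loop control of cuf.kernel in the same form the parser accepts:
/// induction variables with their types, the lower/upper bound and step lists,
/// and the body region without its entry block arguments.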
void printCUFKernelLoopControl(
    mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Region &region,
    mlir::ValueRange lowerbound, mlir::TypeRange lowerboundType,
    mlir::ValueRange upperbound, mlir::TypeRange upperboundType,
    mlir::ValueRange steps, mlir::TypeRange stepType) {
  mlir::ValueRange regionArgs = region.front().getArguments();
  if (!regionArgs.empty()) {
    p << "(";
    llvm::interleaveComma(
        regionArgs, p, [&p](mlir::Value v) { p << v << " : " << v.getType(); });
    p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
      << upperbound << " : " << upperboundType << ") "
      << " step (" << steps << " : " << stepType << ") ";
  }
  p.printRegion(region, /*printEntryBlockArgs=*/false);
}

llvm::LogicalResult cuf::KernelOp::verify() {
  if (getLowerbound().size() != getUpperbound().size() ||
      getLowerbound().size() != getStep().size())
    return emitOpError(
        "expect same number of values in lowerbound, upperbound and step");
  auto reduceAttrs = getReduceAttrs();
  std::size_t reduceAttrsSize = reduceAttrs ? reduceAttrs->size() : 0;
  if (getReduceOperands().size() != reduceAttrsSize)
    return emitOpError("expect same number of values in reduce operands and "
                       "reduce attributes");
  if (reduceAttrs) {
    for (const auto &attr : reduceAttrs.value()) {
      if (!mlir::isa<fir::ReduceAttr>(attr))
        return emitOpError("expect reduce attributes to be ReduceAttr");
    }
  }
  return checkStreamType(*this);
}

//===----------------------------------------------------------------------===//
// RegisterKernelOp
//===----------------------------------------------------------------------===//

mlir::StringAttr cuf::RegisterKernelOp::getKernelModuleName() {
  return getName().getRootReference();
}

mlir::StringAttr cuf::RegisterKernelOp::getKernelName() {
  return getName().getLeafReference();
}

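/// The symbol must reference a kernel in a gpu.module: either a gpu.func
/// marked as kernel or an llvm.func carrying the gpu.kernel attribute. When
/// the module has already been turned into a gpu.binary, no further check is
/// performed.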
mlir::LogicalResult cuf::RegisterKernelOp::verify() {
  if (getKernelName() == getKernelModuleName())
    return emitOpError("expect a module and a kernel name");

  auto mod = getOperation()->getParentOfType<mlir::ModuleOp>();
  if (!mod)
    return emitOpError("expect to be in a module");

  mlir::SymbolTable symTab(mod);
  auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(getKernelModuleName());
  if (!gpuMod) {
    // If already a gpu.binary then stop the check here.
    if (symTab.lookup<mlir::gpu::BinaryOp>(getKernelModuleName()))
      return mlir::success();
    return emitOpError("gpu module not found");
  }

  mlir::SymbolTable gpuSymTab(gpuMod);
  if (auto func = gpuSymTab.lookup<mlir::gpu::GPUFuncOp>(getKernelName())) {
    if (!func.isKernel())
      return emitOpError("only kernel gpu.func can be registered");
    return mlir::success();
  } else if (auto func =
                 gpuSymTab.lookup<mlir::LLVM::LLVMFuncOp>(getKernelName())) {
    if (!func->getAttrOfType<mlir::UnitAttr>(
            mlir::gpu::GPUDialect::getKernelFuncAttrName()))
      return emitOpError("only gpu.kernel llvm.func can be registered");
    return mlir::success();
  }
  return emitOpError("device function not found");
}

//===----------------------------------------------------------------------===//
// SharedMemoryOp
//===----------------------------------------------------------------------===//

void cuf::SharedMemoryOp::build(
    mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Type inType,
    llvm::StringRef uniqName, llvm::StringRef bindcName,
    mlir::ValueRange typeparams, mlir::ValueRange shape,
    llvm::ArrayRef<mlir::NamedAttribute> attributes) {
  mlir::StringAttr nameAttr =
      uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName);
  mlir::StringAttr bindcAttr =
      bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName);
  build(builder, result, wrapAllocaResultType(inType),
        mlir::TypeAttr::get(inType), nameAttr, bindcAttr, typeparams, shape,
        /*offset=*/mlir::Value{});
  result.addAttributes(attributes);
}

//===----------------------------------------------------------------------===//
// StreamCastOp
//===----------------------------------------------------------------------===//

llvm::LogicalResult cuf::StreamCastOp::verify() {
  return checkStreamType(*this);
}

// Tablegen operators

#define GET_OP_CLASSES
#include "flang/Optimizer/Dialect/CUF/CUFOps.cpp.inc"