1//===-- CUFOps.cpp --------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
14#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
15#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
16#include "flang/Optimizer/Dialect/FIRAttr.h"
17#include "flang/Optimizer/Dialect/FIRType.h"
18#include "mlir/Dialect/GPU/IR/GPUDialect.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/IR/Attributes.h"
21#include "mlir/IR/BuiltinAttributes.h"
22#include "mlir/IR/BuiltinOps.h"
23#include "mlir/IR/Diagnostics.h"
24#include "mlir/IR/Matchers.h"
25#include "mlir/IR/OpDefinition.h"
26#include "mlir/IR/PatternMatch.h"
27#include "llvm/ADT/SmallVector.h"
28
29//===----------------------------------------------------------------------===//
30// AllocOp
31//===----------------------------------------------------------------------===//
32
33static mlir::Type wrapAllocaResultType(mlir::Type intype) {
34 if (mlir::isa<fir::ReferenceType>(intype))
35 return {};
36 return fir::ReferenceType::get(intype);
37}
38
39void cuf::AllocOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
40 mlir::Type inType, llvm::StringRef uniqName,
41 llvm::StringRef bindcName,
42 cuf::DataAttributeAttr cudaAttr,
43 mlir::ValueRange typeparams, mlir::ValueRange shape,
44 llvm::ArrayRef<mlir::NamedAttribute> attributes) {
45 mlir::StringAttr nameAttr =
46 uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName);
47 mlir::StringAttr bindcAttr =
48 bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName);
49 build(builder, result, wrapAllocaResultType(inType),
50 mlir::TypeAttr::get(inType), nameAttr, bindcAttr, typeparams, shape,
51 cudaAttr);
52 result.addAttributes(attributes);
53}
54
55template <typename Op>
56static llvm::LogicalResult checkCudaAttr(Op op) {
57 if (op.getDataAttr() == cuf::DataAttribute::Device ||
58 op.getDataAttr() == cuf::DataAttribute::Managed ||
59 op.getDataAttr() == cuf::DataAttribute::Unified ||
60 op.getDataAttr() == cuf::DataAttribute::Pinned ||
61 op.getDataAttr() == cuf::DataAttribute::Shared)
62 return mlir::success();
63 return op.emitOpError()
64 << "expect device, managed, pinned or unified cuda attribute";
65}
66
67llvm::LogicalResult cuf::AllocOp::verify() { return checkCudaAttr(*this); }
68
69//===----------------------------------------------------------------------===//
70// FreeOp
71//===----------------------------------------------------------------------===//
72
73llvm::LogicalResult cuf::FreeOp::verify() { return checkCudaAttr(*this); }
74
75//===----------------------------------------------------------------------===//
76// AllocateOp
77//===----------------------------------------------------------------------===//
78
79template <typename OpTy>
80static llvm::LogicalResult checkStreamType(OpTy op) {
81 if (!op.getStream())
82 return mlir::success();
83 if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getStream().getType()))
84 if (!refTy.getEleTy().isInteger(64))
85 return op.emitOpError("stream is expected to be an i64 reference");
86 return mlir::success();
87}
88
89llvm::LogicalResult cuf::AllocateOp::verify() {
90 if (getPinned() && getStream())
91 return emitOpError("pinned and stream cannot appears at the same time");
92 if (!mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getBox().getType())))
93 return emitOpError(
94 "expect box to be a reference to a class or box type value");
95 if (getSource() &&
96 !mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getSource().getType())))
97 return emitOpError(
98 "expect source to be a reference to/or a class or box type value");
99 if (getErrmsg() &&
100 !mlir::isa<fir::BoxType>(fir::unwrapRefType(getErrmsg().getType())))
101 return emitOpError(
102 "expect errmsg to be a reference to/or a box type value");
103 if (getErrmsg() && !getHasStat())
104 return emitOpError("expect stat attribute when errmsg is provided");
105 return checkStreamType(*this);
106}
107
108//===----------------------------------------------------------------------===//
109// DataTransferOp
110//===----------------------------------------------------------------------===//
111
112llvm::LogicalResult cuf::DataTransferOp::verify() {
113 mlir::Type srcTy = getSrc().getType();
114 mlir::Type dstTy = getDst().getType();
115 if (getShape()) {
116 if (!fir::isa_ref_type(srcTy) && !fir::isa_ref_type(dstTy))
117 return emitOpError()
118 << "shape can only be specified on data transfer with references";
119 }
120 if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) ||
121 (fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)) ||
122 (fir::isa_ref_type(srcTy) && fir::isa_box_type(dstTy)) ||
123 (fir::isa_box_type(srcTy) && fir::isa_ref_type(dstTy)))
124 return mlir::success();
125 if (fir::isa_trivial(srcTy) &&
126 matchPattern(getSrc().getDefiningOp(), mlir::m_Constant()))
127 return mlir::success();
128
129 return emitOpError()
130 << "expect src and dst to be references or descriptors or src to "
131 "be a constant: "
132 << srcTy << " - " << dstTy;
133}
134
135//===----------------------------------------------------------------------===//
136// DeallocateOp
137//===----------------------------------------------------------------------===//
138
139llvm::LogicalResult cuf::DeallocateOp::verify() {
140 if (!mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getBox().getType())))
141 return emitOpError(
142 "expect box to be a reference to class or box type value");
143 if (getErrmsg() &&
144 !mlir::isa<fir::BoxType>(fir::unwrapRefType(getErrmsg().getType())))
145 return emitOpError(
146 "expect errmsg to be a reference to/or a box type value");
147 if (getErrmsg() && !getHasStat())
148 return emitOpError("expect stat attribute when errmsg is provided");
149 return mlir::success();
150}
151
152//===----------------------------------------------------------------------===//
153// KernelLaunchOp
154//===----------------------------------------------------------------------===//
155
156llvm::LogicalResult cuf::KernelLaunchOp::verify() {
157 return checkStreamType(*this);
158}
159
160//===----------------------------------------------------------------------===//
161// KernelOp
162//===----------------------------------------------------------------------===//
163
164llvm::SmallVector<mlir::Region *> cuf::KernelOp::getLoopRegions() {
165 return {&getRegion()};
166}
167
168mlir::ParseResult parseCUFKernelValues(
169 mlir::OpAsmParser &parser,
170 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &values,
171 llvm::SmallVectorImpl<mlir::Type> &types) {
172 if (mlir::succeeded(Result: parser.parseOptionalStar()))
173 return mlir::success();
174
175 if (mlir::succeeded(Result: parser.parseOptionalLParen())) {
176 if (mlir::failed(Result: parser.parseCommaSeparatedList(
177 delimiter: mlir::AsmParser::Delimiter::None, parseElementFn: [&]() {
178 if (parser.parseOperand(result&: values.emplace_back()))
179 return mlir::failure();
180 return mlir::success();
181 })))
182 return mlir::failure();
183 auto builder = parser.getBuilder();
184 for (size_t i = 0; i < values.size(); i++) {
185 types.emplace_back(Args: builder.getI32Type());
186 }
187 if (parser.parseRParen())
188 return mlir::failure();
189 } else {
190 if (parser.parseOperand(result&: values.emplace_back()))
191 return mlir::failure();
192 auto builder = parser.getBuilder();
193 types.emplace_back(Args: builder.getI32Type());
194 return mlir::success();
195 }
196 return mlir::success();
197}
198
199void printCUFKernelValues(mlir::OpAsmPrinter &p, mlir::Operation *op,
200 mlir::ValueRange values, mlir::TypeRange types) {
201 if (values.empty())
202 p << "*";
203
204 if (values.size() > 1)
205 p << "(";
206 llvm::interleaveComma(c: values, os&: p, each_fn: [&p](mlir::Value v) { p << v; });
207 if (values.size() > 1)
208 p << ")";
209}
210
211mlir::ParseResult parseCUFKernelLoopControl(
212 mlir::OpAsmParser &parser, mlir::Region &region,
213 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &lowerbound,
214 llvm::SmallVectorImpl<mlir::Type> &lowerboundType,
215 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &upperbound,
216 llvm::SmallVectorImpl<mlir::Type> &upperboundType,
217 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &step,
218 llvm::SmallVectorImpl<mlir::Type> &stepType) {
219
220 llvm::SmallVector<mlir::OpAsmParser::Argument> inductionVars;
221 if (parser.parseLParen() ||
222 parser.parseArgumentList(result&: inductionVars,
223 delimiter: mlir::OpAsmParser::Delimiter::None,
224 /*allowType=*/true) ||
225 parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
226 parser.parseOperandList(result&: lowerbound, requiredOperandCount: inductionVars.size(),
227 delimiter: mlir::OpAsmParser::Delimiter::None) ||
228 parser.parseColonTypeList(result&: lowerboundType) || parser.parseRParen() ||
229 parser.parseKeyword(keyword: "to") || parser.parseLParen() ||
230 parser.parseOperandList(result&: upperbound, requiredOperandCount: inductionVars.size(),
231 delimiter: mlir::OpAsmParser::Delimiter::None) ||
232 parser.parseColonTypeList(result&: upperboundType) || parser.parseRParen() ||
233 parser.parseKeyword(keyword: "step") || parser.parseLParen() ||
234 parser.parseOperandList(result&: step, requiredOperandCount: inductionVars.size(),
235 delimiter: mlir::OpAsmParser::Delimiter::None) ||
236 parser.parseColonTypeList(result&: stepType) || parser.parseRParen())
237 return mlir::failure();
238 return parser.parseRegion(region, arguments: inductionVars);
239}
240
241void printCUFKernelLoopControl(
242 mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Region &region,
243 mlir::ValueRange lowerbound, mlir::TypeRange lowerboundType,
244 mlir::ValueRange upperbound, mlir::TypeRange upperboundType,
245 mlir::ValueRange steps, mlir::TypeRange stepType) {
246 mlir::ValueRange regionArgs = region.front().getArguments();
247 if (!regionArgs.empty()) {
248 p << "(";
249 llvm::interleaveComma(
250 c: regionArgs, os&: p, each_fn: [&p](mlir::Value v) { p << v << " : " << v.getType(); });
251 p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
252 << upperbound << " : " << upperboundType << ") "
253 << " step (" << steps << " : " << stepType << ") ";
254 }
255 p.printRegion(blocks&: region, /*printEntryBlockArgs=*/false);
256}
257
258llvm::LogicalResult cuf::KernelOp::verify() {
259 if (getLowerbound().size() != getUpperbound().size() ||
260 getLowerbound().size() != getStep().size())
261 return emitOpError(
262 "expect same number of values in lowerbound, upperbound and step");
263 auto reduceAttrs = getReduceAttrs();
264 std::size_t reduceAttrsSize = reduceAttrs ? reduceAttrs->size() : 0;
265 if (getReduceOperands().size() != reduceAttrsSize)
266 return emitOpError("expect same number of values in reduce operands and "
267 "reduce attributes");
268 if (reduceAttrs) {
269 for (const auto &attr : reduceAttrs.value()) {
270 if (!mlir::isa<fir::ReduceAttr>(attr))
271 return emitOpError("expect reduce attributes to be ReduceAttr");
272 }
273 }
274 return checkStreamType(*this);
275}
276
277//===----------------------------------------------------------------------===//
278// RegisterKernelOp
279//===----------------------------------------------------------------------===//
280
281mlir::StringAttr cuf::RegisterKernelOp::getKernelModuleName() {
282 return getName().getRootReference();
283}
284
285mlir::StringAttr cuf::RegisterKernelOp::getKernelName() {
286 return getName().getLeafReference();
287}
288
289mlir::LogicalResult cuf::RegisterKernelOp::verify() {
290 if (getKernelName() == getKernelModuleName())
291 return emitOpError("expect a module and a kernel name");
292
293 auto mod = getOperation()->getParentOfType<mlir::ModuleOp>();
294 if (!mod)
295 return emitOpError("expect to be in a module");
296
297 mlir::SymbolTable symTab(mod);
298 auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(getKernelModuleName());
299 if (!gpuMod) {
300 // If already a gpu.binary then stop the check here.
301 if (symTab.lookup<mlir::gpu::BinaryOp>(getKernelModuleName()))
302 return mlir::success();
303 return emitOpError("gpu module not found");
304 }
305
306 mlir::SymbolTable gpuSymTab(gpuMod);
307 if (auto func = gpuSymTab.lookup<mlir::gpu::GPUFuncOp>(getKernelName())) {
308 if (!func.isKernel())
309 return emitOpError("only kernel gpu.func can be registered");
310 return mlir::success();
311 } else if (auto func =
312 gpuSymTab.lookup<mlir::LLVM::LLVMFuncOp>(getKernelName())) {
313 if (!func->getAttrOfType<mlir::UnitAttr>(
314 mlir::gpu::GPUDialect::getKernelFuncAttrName()))
315 return emitOpError("only gpu.kernel llvm.func can be registered");
316 return mlir::success();
317 }
318 return emitOpError("device function not found");
319}
320
321//===----------------------------------------------------------------------===//
322// SharedMemoryOp
323//===----------------------------------------------------------------------===//
324
325void cuf::SharedMemoryOp::build(
326 mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Type inType,
327 llvm::StringRef uniqName, llvm::StringRef bindcName,
328 mlir::ValueRange typeparams, mlir::ValueRange shape,
329 llvm::ArrayRef<mlir::NamedAttribute> attributes) {
330 mlir::StringAttr nameAttr =
331 uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName);
332 mlir::StringAttr bindcAttr =
333 bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName);
334 build(builder, result, wrapAllocaResultType(inType),
335 mlir::TypeAttr::get(inType), nameAttr, bindcAttr, typeparams, shape,
336 /*offset=*/mlir::Value{});
337 result.addAttributes(attributes);
338}
339
340//===----------------------------------------------------------------------===//
341// StreamCastOp
342//===----------------------------------------------------------------------===//
343
344llvm::LogicalResult cuf::StreamCastOp::verify() {
345 return checkStreamType(*this);
346}
347
// TableGen-generated operation definitions.
349
350#define GET_OP_CLASSES
351#include "flang/Optimizer/Dialect/CUF/CUFOps.cpp.inc"
352

// Source: flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp