1//====- LowerToLLVM.cpp - Lowering from CIR to LLVMIR ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of CIR operations to LLVMIR.
10//
11//===----------------------------------------------------------------------===//
12
13#include "LowerToLLVM.h"
14
15#include <deque>
16#include <optional>
17
18#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
19#include "mlir/Dialect/DLTI/DLTI.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
23#include "mlir/IR/BuiltinAttributes.h"
24#include "mlir/IR/BuiltinDialect.h"
25#include "mlir/IR/BuiltinOps.h"
26#include "mlir/IR/Types.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Pass/PassManager.h"
29#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
30#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
31#include "mlir/Target/LLVMIR/Export.h"
32#include "mlir/Transforms/DialectConversion.h"
33#include "clang/CIR/Dialect/IR/CIRAttrs.h"
34#include "clang/CIR/Dialect/IR/CIRDialect.h"
35#include "clang/CIR/Dialect/Passes.h"
36#include "clang/CIR/LoweringHelpers.h"
37#include "clang/CIR/MissingFeatures.h"
38#include "clang/CIR/Passes.h"
39#include "llvm/ADT/TypeSwitch.h"
40#include "llvm/IR/Module.h"
41#include "llvm/Support/ErrorHandling.h"
42#include "llvm/Support/TimeProfiler.h"
43
44using namespace cir;
45using namespace llvm;
46
47namespace cir {
48namespace direct {
49
50//===----------------------------------------------------------------------===//
51// Helper Methods
52//===----------------------------------------------------------------------===//
53
54namespace {
55/// If the given type is a vector type, return the vector's element type.
56/// Otherwise return the given type unchanged.
57mlir::Type elementTypeIfVector(mlir::Type type) {
58 return llvm::TypeSwitch<mlir::Type, mlir::Type>(type)
59 .Case<cir::VectorType, mlir::VectorType>(
60 [](auto p) { return p.getElementType(); })
61 .Default([](mlir::Type p) { return p; });
62}
63} // namespace
64
65/// Given a type convertor and a data layout, convert the given type to a type
66/// that is suitable for memory operations. For example, this can be used to
67/// lower cir.bool accesses to i8.
68static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter,
69 mlir::DataLayout const &dataLayout,
70 mlir::Type type) {
71 // TODO(cir): Handle other types similarly to clang's codegen
72 // convertTypeForMemory
73 if (isa<cir::BoolType>(type)) {
74 return mlir::IntegerType::get(type.getContext(),
75 dataLayout.getTypeSizeInBits(type));
76 }
77
78 return converter.convertType(t: type);
79}
80
81static mlir::Value createIntCast(mlir::OpBuilder &bld, mlir::Value src,
82 mlir::IntegerType dstTy,
83 bool isSigned = false) {
84 mlir::Type srcTy = src.getType();
85 assert(mlir::isa<mlir::IntegerType>(srcTy));
86
87 unsigned srcWidth = mlir::cast<mlir::IntegerType>(srcTy).getWidth();
88 unsigned dstWidth = mlir::cast<mlir::IntegerType>(dstTy).getWidth();
89 mlir::Location loc = src.getLoc();
90
91 if (dstWidth > srcWidth && isSigned)
92 return bld.create<mlir::LLVM::SExtOp>(loc, dstTy, src);
93 if (dstWidth > srcWidth)
94 return bld.create<mlir::LLVM::ZExtOp>(loc, dstTy, src);
95 if (dstWidth < srcWidth)
96 return bld.create<mlir::LLVM::TruncOp>(loc, dstTy, src);
97 return bld.create<mlir::LLVM::BitcastOp>(loc, dstTy, src);
98}
99
100/// Emits the value from memory as expected by its users. Should be called when
101/// the memory represetnation of a CIR type is not equal to its scalar
102/// representation.
103static mlir::Value emitFromMemory(mlir::ConversionPatternRewriter &rewriter,
104 mlir::DataLayout const &dataLayout,
105 cir::LoadOp op, mlir::Value value) {
106
107 // TODO(cir): Handle other types similarly to clang's codegen EmitFromMemory
108 if (auto boolTy = mlir::dyn_cast<cir::BoolType>(op.getType())) {
109 // Create a cast value from specified size in datalayout to i1
110 assert(value.getType().isInteger(dataLayout.getTypeSizeInBits(boolTy)));
111 return createIntCast(rewriter, value, rewriter.getI1Type());
112 }
113
114 return value;
115}
116
117/// Emits a value to memory with the expected scalar type. Should be called when
118/// the memory represetnation of a CIR type is not equal to its scalar
119/// representation.
120static mlir::Value emitToMemory(mlir::ConversionPatternRewriter &rewriter,
121 mlir::DataLayout const &dataLayout,
122 mlir::Type origType, mlir::Value value) {
123
124 // TODO(cir): Handle other types similarly to clang's codegen EmitToMemory
125 if (auto boolTy = mlir::dyn_cast<cir::BoolType>(origType)) {
126 // Create zext of value from i1 to i8
127 mlir::IntegerType memType =
128 rewriter.getIntegerType(dataLayout.getTypeSizeInBits(boolTy));
129 return createIntCast(rewriter, value, memType);
130 }
131
132 return value;
133}
134
135mlir::LLVM::Linkage convertLinkage(cir::GlobalLinkageKind linkage) {
136 using CIR = cir::GlobalLinkageKind;
137 using LLVM = mlir::LLVM::Linkage;
138
139 switch (linkage) {
140 case CIR::AvailableExternallyLinkage:
141 return LLVM::AvailableExternally;
142 case CIR::CommonLinkage:
143 return LLVM::Common;
144 case CIR::ExternalLinkage:
145 return LLVM::External;
146 case CIR::ExternalWeakLinkage:
147 return LLVM::ExternWeak;
148 case CIR::InternalLinkage:
149 return LLVM::Internal;
150 case CIR::LinkOnceAnyLinkage:
151 return LLVM::Linkonce;
152 case CIR::LinkOnceODRLinkage:
153 return LLVM::LinkonceODR;
154 case CIR::PrivateLinkage:
155 return LLVM::Private;
156 case CIR::WeakAnyLinkage:
157 return LLVM::Weak;
158 case CIR::WeakODRLinkage:
159 return LLVM::WeakODR;
160 };
161 llvm_unreachable("Unknown CIR linkage type");
162}
163
164static mlir::Value getLLVMIntCast(mlir::ConversionPatternRewriter &rewriter,
165 mlir::Value llvmSrc, mlir::Type llvmDstIntTy,
166 bool isUnsigned, uint64_t cirSrcWidth,
167 uint64_t cirDstIntWidth) {
168 if (cirSrcWidth == cirDstIntWidth)
169 return llvmSrc;
170
171 auto loc = llvmSrc.getLoc();
172 if (cirSrcWidth < cirDstIntWidth) {
173 if (isUnsigned)
174 return rewriter.create<mlir::LLVM::ZExtOp>(loc, llvmDstIntTy, llvmSrc);
175 return rewriter.create<mlir::LLVM::SExtOp>(loc, llvmDstIntTy, llvmSrc);
176 }
177
178 // Otherwise truncate
179 return rewriter.create<mlir::LLVM::TruncOp>(loc, llvmDstIntTy, llvmSrc);
180}
181
182class CIRAttrToValue {
183public:
184 CIRAttrToValue(mlir::Operation *parentOp,
185 mlir::ConversionPatternRewriter &rewriter,
186 const mlir::TypeConverter *converter)
187 : parentOp(parentOp), rewriter(rewriter), converter(converter) {}
188
189 mlir::Value visit(mlir::Attribute attr) {
190 return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
191 .Case<cir::IntAttr, cir::FPAttr, cir::ConstComplexAttr,
192 cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
193 cir::ZeroAttr>([&](auto attrT) { return visitCirAttr(attrT); })
194 .Default([&](auto attrT) { return mlir::Value(); });
195 }
196
197 mlir::Value visitCirAttr(cir::IntAttr intAttr);
198 mlir::Value visitCirAttr(cir::FPAttr fltAttr);
199 mlir::Value visitCirAttr(cir::ConstComplexAttr complexAttr);
200 mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
201 mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
202 mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
203 mlir::Value visitCirAttr(cir::ZeroAttr attr);
204
205private:
206 mlir::Operation *parentOp;
207 mlir::ConversionPatternRewriter &rewriter;
208 const mlir::TypeConverter *converter;
209};
210
211/// Switches on the type of attribute and calls the appropriate conversion.
212mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
213 const mlir::Attribute attr,
214 mlir::ConversionPatternRewriter &rewriter,
215 const mlir::TypeConverter *converter) {
216 CIRAttrToValue valueConverter(parentOp, rewriter, converter);
217 mlir::Value value = valueConverter.visit(attr);
218 if (!value)
219 llvm_unreachable("unhandled attribute type");
220 return value;
221}
222
223/// IntAttr visitor.
224mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) {
225 mlir::Location loc = parentOp->getLoc();
226 return rewriter.create<mlir::LLVM::ConstantOp>(
227 loc, converter->convertType(intAttr.getType()), intAttr.getValue());
228}
229
230/// FPAttr visitor.
231mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) {
232 mlir::Location loc = parentOp->getLoc();
233 return rewriter.create<mlir::LLVM::ConstantOp>(
234 loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
235}
236
237/// ConstComplexAttr visitor.
238mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstComplexAttr complexAttr) {
239 auto complexType = mlir::cast<cir::ComplexType>(complexAttr.getType());
240 mlir::Type complexElemTy = complexType.getElementType();
241 mlir::Type complexElemLLVMTy = converter->convertType(t: complexElemTy);
242
243 mlir::Attribute components[2];
244 if (const auto intType = mlir::dyn_cast<cir::IntType>(complexElemTy)) {
245 components[0] = rewriter.getIntegerAttr(
246 complexElemLLVMTy,
247 mlir::cast<cir::IntAttr>(complexAttr.getReal()).getValue());
248 components[1] = rewriter.getIntegerAttr(
249 complexElemLLVMTy,
250 mlir::cast<cir::IntAttr>(complexAttr.getImag()).getValue());
251 } else {
252 components[0] = rewriter.getFloatAttr(
253 complexElemLLVMTy,
254 mlir::cast<cir::FPAttr>(complexAttr.getReal()).getValue());
255 components[1] = rewriter.getFloatAttr(
256 complexElemLLVMTy,
257 mlir::cast<cir::FPAttr>(complexAttr.getImag()).getValue());
258 }
259
260 mlir::Location loc = parentOp->getLoc();
261 return rewriter.create<mlir::LLVM::ConstantOp>(
262 loc, converter->convertType(complexAttr.getType()),
263 rewriter.getArrayAttr(components));
264}
265
266/// ConstPtrAttr visitor.
267mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
268 mlir::Location loc = parentOp->getLoc();
269 if (ptrAttr.isNullValue()) {
270 return rewriter.create<mlir::LLVM::ZeroOp>(
271 loc, converter->convertType(ptrAttr.getType()));
272 }
273 mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>());
274 mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>(
275 loc, rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())),
276 ptrAttr.getValue().getInt());
277 return rewriter.create<mlir::LLVM::IntToPtrOp>(
278 loc, converter->convertType(ptrAttr.getType()), ptrVal);
279}
280
281// ConstArrayAttr visitor
282mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) {
283 mlir::Type llvmTy = converter->convertType(attr.getType());
284 mlir::Location loc = parentOp->getLoc();
285 mlir::Value result;
286
287 if (attr.hasTrailingZeros()) {
288 mlir::Type arrayTy = attr.getType();
289 result = rewriter.create<mlir::LLVM::ZeroOp>(
290 loc, converter->convertType(arrayTy));
291 } else {
292 result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);
293 }
294
295 // Iteratively lower each constant element of the array.
296 if (auto arrayAttr = mlir::dyn_cast<mlir::ArrayAttr>(attr.getElts())) {
297 for (auto [idx, elt] : llvm::enumerate(arrayAttr)) {
298 mlir::DataLayout dataLayout(parentOp->getParentOfType<mlir::ModuleOp>());
299 mlir::Value init = visit(elt);
300 result =
301 rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
302 }
303 } else if (auto strAttr = mlir::dyn_cast<mlir::StringAttr>(attr.getElts())) {
304 // TODO(cir): this diverges from traditional lowering. Normally the string
305 // would be a global constant that is memcopied.
306 auto arrayTy = mlir::dyn_cast<cir::ArrayType>(strAttr.getType());
307 assert(arrayTy && "String attribute must have an array type");
308 mlir::Type eltTy = arrayTy.getElementType();
309 for (auto [idx, elt] : llvm::enumerate(strAttr)) {
310 auto init = rewriter.create<mlir::LLVM::ConstantOp>(
311 loc, converter->convertType(eltTy), elt);
312 result =
313 rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
314 }
315 } else {
316 llvm_unreachable("unexpected ConstArrayAttr elements");
317 }
318
319 return result;
320}
321
322/// ConstVectorAttr visitor.
323mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) {
324 const mlir::Type llvmTy = converter->convertType(attr.getType());
325 const mlir::Location loc = parentOp->getLoc();
326
327 SmallVector<mlir::Attribute> mlirValues;
328 for (const mlir::Attribute elementAttr : attr.getElts()) {
329 mlir::Attribute mlirAttr;
330 if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(elementAttr)) {
331 mlirAttr = rewriter.getIntegerAttr(
332 converter->convertType(intAttr.getType()), intAttr.getValue());
333 } else if (auto floatAttr = mlir::dyn_cast<cir::FPAttr>(elementAttr)) {
334 mlirAttr = rewriter.getFloatAttr(
335 converter->convertType(floatAttr.getType()), floatAttr.getValue());
336 } else {
337 llvm_unreachable(
338 "vector constant with an element that is neither an int nor a float");
339 }
340 mlirValues.push_back(mlirAttr);
341 }
342
343 return rewriter.create<mlir::LLVM::ConstantOp>(
344 loc, llvmTy,
345 mlir::DenseElementsAttr::get(mlir::cast<mlir::ShapedType>(llvmTy),
346 mlirValues));
347}
348
349/// ZeroAttr visitor.
350mlir::Value CIRAttrToValue::visitCirAttr(cir::ZeroAttr attr) {
351 mlir::Location loc = parentOp->getLoc();
352 return rewriter.create<mlir::LLVM::ZeroOp>(
353 loc, converter->convertType(attr.getType()));
354}
355
356// This class handles rewriting initializer attributes for types that do not
357// require region initialization.
358class GlobalInitAttrRewriter {
359public:
360 GlobalInitAttrRewriter(mlir::Type type,
361 mlir::ConversionPatternRewriter &rewriter)
362 : llvmType(type), rewriter(rewriter) {}
363
364 mlir::Attribute visit(mlir::Attribute attr) {
365 return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr)
366 .Case<cir::IntAttr, cir::FPAttr, cir::BoolAttr>(
367 [&](auto attrT) { return visitCirAttr(attrT); })
368 .Default([&](auto attrT) { return mlir::Attribute(); });
369 }
370
371 mlir::Attribute visitCirAttr(cir::IntAttr attr) {
372 return rewriter.getIntegerAttr(llvmType, attr.getValue());
373 }
374
375 mlir::Attribute visitCirAttr(cir::FPAttr attr) {
376 return rewriter.getFloatAttr(llvmType, attr.getValue());
377 }
378
379 mlir::Attribute visitCirAttr(cir::BoolAttr attr) {
380 return rewriter.getBoolAttr(value: attr.getValue());
381 }
382
383private:
384 mlir::Type llvmType;
385 mlir::ConversionPatternRewriter &rewriter;
386};
387
388// This pass requires the CIR to be in a "flat" state. All blocks in each
389// function must belong to the parent region. Once scopes and control flow
390// are implemented in CIR, a pass will be run before this one to flatten
391// the CIR and get it into the state that this pass requires.
392struct ConvertCIRToLLVMPass
393 : public mlir::PassWrapper<ConvertCIRToLLVMPass,
394 mlir::OperationPass<mlir::ModuleOp>> {
395 void getDependentDialects(mlir::DialectRegistry &registry) const override {
396 registry.insert<mlir::BuiltinDialect, mlir::DLTIDialect,
397 mlir::LLVM::LLVMDialect, mlir::func::FuncDialect>();
398 }
399 void runOnOperation() final;
400
401 void processCIRAttrs(mlir::ModuleOp module);
402
403 StringRef getDescription() const override {
404 return "Convert the prepared CIR dialect module to LLVM dialect";
405 }
406
407 StringRef getArgument() const override { return "cir-flat-to-llvm"; }
408};
409
410mlir::LogicalResult CIRToLLVMBrCondOpLowering::matchAndRewrite(
411 cir::BrCondOp brOp, OpAdaptor adaptor,
412 mlir::ConversionPatternRewriter &rewriter) const {
413 // When ZExtOp is implemented, we'll need to check if the condition is a
414 // ZExtOp and if so, delete it if it has a single use.
415 assert(!cir::MissingFeatures::zextOp());
416
417 mlir::Value i1Condition = adaptor.getCond();
418
419 rewriter.replaceOpWithNewOp<mlir::LLVM::CondBrOp>(
420 brOp, i1Condition, brOp.getDestTrue(), adaptor.getDestOperandsTrue(),
421 brOp.getDestFalse(), adaptor.getDestOperandsFalse());
422
423 return mlir::success();
424}
425
426mlir::Type CIRToLLVMCastOpLowering::convertTy(mlir::Type ty) const {
427 return getTypeConverter()->convertType(ty);
428}
429
430mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
431 cir::CastOp castOp, OpAdaptor adaptor,
432 mlir::ConversionPatternRewriter &rewriter) const {
433 // For arithmetic conversions, LLVM IR uses the same instruction to convert
434 // both individual scalars and entire vectors. This lowering pass handles
435 // both situations.
436
437 switch (castOp.getKind()) {
438 case cir::CastKind::array_to_ptrdecay: {
439 const auto ptrTy = mlir::cast<cir::PointerType>(castOp.getType());
440 mlir::Value sourceValue = adaptor.getSrc();
441 mlir::Type targetType = convertTy(ty: ptrTy);
442 mlir::Type elementTy = convertTypeForMemory(*getTypeConverter(), dataLayout,
443 ptrTy.getPointee());
444 llvm::SmallVector<mlir::LLVM::GEPArg> offset{0};
445 rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
446 castOp, targetType, elementTy, sourceValue, offset);
447 break;
448 }
449 case cir::CastKind::int_to_bool: {
450 mlir::Value llvmSrcVal = adaptor.getSrc();
451 mlir::Value zeroInt = rewriter.create<mlir::LLVM::ConstantOp>(
452 castOp.getLoc(), llvmSrcVal.getType(), 0);
453 rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
454 castOp, mlir::LLVM::ICmpPredicate::ne, llvmSrcVal, zeroInt);
455 break;
456 }
457 case cir::CastKind::integral: {
458 mlir::Type srcType = castOp.getSrc().getType();
459 mlir::Type dstType = castOp.getType();
460 mlir::Value llvmSrcVal = adaptor.getSrc();
461 mlir::Type llvmDstType = getTypeConverter()->convertType(dstType);
462 cir::IntType srcIntType =
463 mlir::cast<cir::IntType>(elementTypeIfVector(srcType));
464 cir::IntType dstIntType =
465 mlir::cast<cir::IntType>(elementTypeIfVector(dstType));
466 rewriter.replaceOp(castOp, getLLVMIntCast(rewriter, llvmSrcVal, llvmDstType,
467 srcIntType.isUnsigned(),
468 srcIntType.getWidth(),
469 dstIntType.getWidth()));
470 break;
471 }
472 case cir::CastKind::floating: {
473 mlir::Value llvmSrcVal = adaptor.getSrc();
474 mlir::Type llvmDstTy = getTypeConverter()->convertType(castOp.getType());
475
476 mlir::Type srcTy = elementTypeIfVector(castOp.getSrc().getType());
477 mlir::Type dstTy = elementTypeIfVector(castOp.getType());
478
479 if (!mlir::isa<cir::CIRFPTypeInterface>(dstTy) ||
480 !mlir::isa<cir::CIRFPTypeInterface>(srcTy))
481 return castOp.emitError() << "NYI cast from " << srcTy << " to " << dstTy;
482
483 auto getFloatWidth = [](mlir::Type ty) -> unsigned {
484 return mlir::cast<cir::CIRFPTypeInterface>(ty).getWidth();
485 };
486
487 if (getFloatWidth(srcTy) > getFloatWidth(dstTy))
488 rewriter.replaceOpWithNewOp<mlir::LLVM::FPTruncOp>(castOp, llvmDstTy,
489 llvmSrcVal);
490 else
491 rewriter.replaceOpWithNewOp<mlir::LLVM::FPExtOp>(castOp, llvmDstTy,
492 llvmSrcVal);
493 return mlir::success();
494 }
495 case cir::CastKind::int_to_ptr: {
496 auto dstTy = mlir::cast<cir::PointerType>(castOp.getType());
497 mlir::Value llvmSrcVal = adaptor.getSrc();
498 mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
499 rewriter.replaceOpWithNewOp<mlir::LLVM::IntToPtrOp>(castOp, llvmDstTy,
500 llvmSrcVal);
501 return mlir::success();
502 }
503 case cir::CastKind::ptr_to_int: {
504 auto dstTy = mlir::cast<cir::IntType>(castOp.getType());
505 mlir::Value llvmSrcVal = adaptor.getSrc();
506 mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
507 rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(castOp, llvmDstTy,
508 llvmSrcVal);
509 return mlir::success();
510 }
511 case cir::CastKind::float_to_bool: {
512 mlir::Value llvmSrcVal = adaptor.getSrc();
513 auto kind = mlir::LLVM::FCmpPredicate::une;
514
515 // Check if float is not equal to zero.
516 auto zeroFloat = rewriter.create<mlir::LLVM::ConstantOp>(
517 castOp.getLoc(), llvmSrcVal.getType(),
518 mlir::FloatAttr::get(llvmSrcVal.getType(), 0.0));
519
520 // Extend comparison result to either bool (C++) or int (C).
521 rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>(castOp, kind, llvmSrcVal,
522 zeroFloat);
523
524 return mlir::success();
525 }
526 case cir::CastKind::bool_to_int: {
527 auto dstTy = mlir::cast<cir::IntType>(castOp.getType());
528 mlir::Value llvmSrcVal = adaptor.getSrc();
529 auto llvmSrcTy = mlir::cast<mlir::IntegerType>(llvmSrcVal.getType());
530 auto llvmDstTy =
531 mlir::cast<mlir::IntegerType>(getTypeConverter()->convertType(dstTy));
532 if (llvmSrcTy.getWidth() == llvmDstTy.getWidth())
533 rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy,
534 llvmSrcVal);
535 else
536 rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(castOp, llvmDstTy,
537 llvmSrcVal);
538 return mlir::success();
539 }
540 case cir::CastKind::bool_to_float: {
541 mlir::Type dstTy = castOp.getType();
542 mlir::Value llvmSrcVal = adaptor.getSrc();
543 mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
544 rewriter.replaceOpWithNewOp<mlir::LLVM::UIToFPOp>(castOp, llvmDstTy,
545 llvmSrcVal);
546 return mlir::success();
547 }
548 case cir::CastKind::int_to_float: {
549 mlir::Type dstTy = castOp.getType();
550 mlir::Value llvmSrcVal = adaptor.getSrc();
551 mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
552 if (mlir::cast<cir::IntType>(elementTypeIfVector(castOp.getSrc().getType()))
553 .isSigned())
554 rewriter.replaceOpWithNewOp<mlir::LLVM::SIToFPOp>(castOp, llvmDstTy,
555 llvmSrcVal);
556 else
557 rewriter.replaceOpWithNewOp<mlir::LLVM::UIToFPOp>(castOp, llvmDstTy,
558 llvmSrcVal);
559 return mlir::success();
560 }
561 case cir::CastKind::float_to_int: {
562 mlir::Type dstTy = castOp.getType();
563 mlir::Value llvmSrcVal = adaptor.getSrc();
564 mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
565 if (mlir::cast<cir::IntType>(elementTypeIfVector(castOp.getType()))
566 .isSigned())
567 rewriter.replaceOpWithNewOp<mlir::LLVM::FPToSIOp>(castOp, llvmDstTy,
568 llvmSrcVal);
569 else
570 rewriter.replaceOpWithNewOp<mlir::LLVM::FPToUIOp>(castOp, llvmDstTy,
571 llvmSrcVal);
572 return mlir::success();
573 }
574 case cir::CastKind::bitcast: {
575 mlir::Type dstTy = castOp.getType();
576 mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
577
578 assert(!MissingFeatures::cxxABI());
579 assert(!MissingFeatures::dataMemberType());
580
581 mlir::Value llvmSrcVal = adaptor.getSrc();
582 rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy,
583 llvmSrcVal);
584 return mlir::success();
585 }
586 case cir::CastKind::ptr_to_bool: {
587 mlir::Value llvmSrcVal = adaptor.getSrc();
588 mlir::Value zeroPtr = rewriter.create<mlir::LLVM::ZeroOp>(
589 castOp.getLoc(), llvmSrcVal.getType());
590 rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
591 castOp, mlir::LLVM::ICmpPredicate::ne, llvmSrcVal, zeroPtr);
592 break;
593 }
594 case cir::CastKind::address_space: {
595 mlir::Type dstTy = castOp.getType();
596 mlir::Value llvmSrcVal = adaptor.getSrc();
597 mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
598 rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(castOp, llvmDstTy,
599 llvmSrcVal);
600 break;
601 }
602 case cir::CastKind::member_ptr_to_bool:
603 assert(!MissingFeatures::cxxABI());
604 assert(!MissingFeatures::methodType());
605 break;
606 default: {
607 return castOp.emitError("Unhandled cast kind: ")
608 << castOp.getKindAttrName();
609 }
610 }
611
612 return mlir::success();
613}
614
615mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
616 cir::PtrStrideOp ptrStrideOp, OpAdaptor adaptor,
617 mlir::ConversionPatternRewriter &rewriter) const {
618
619 const mlir::TypeConverter *tc = getTypeConverter();
620 const mlir::Type resultTy = tc->convertType(ptrStrideOp.getType());
621
622 mlir::Type elementTy =
623 convertTypeForMemory(*tc, dataLayout, ptrStrideOp.getElementTy());
624 mlir::MLIRContext *ctx = elementTy.getContext();
625
626 // void and function types doesn't really have a layout to use in GEPs,
627 // make it i8 instead.
628 if (mlir::isa<mlir::LLVM::LLVMVoidType>(elementTy) ||
629 mlir::isa<mlir::LLVM::LLVMFunctionType>(elementTy))
630 elementTy = mlir::IntegerType::get(elementTy.getContext(), 8,
631 mlir::IntegerType::Signless);
632 // Zero-extend, sign-extend or trunc the pointer value.
633 mlir::Value index = adaptor.getStride();
634 const unsigned width =
635 mlir::cast<mlir::IntegerType>(index.getType()).getWidth();
636 const std::optional<std::uint64_t> layoutWidth =
637 dataLayout.getTypeIndexBitwidth(t: adaptor.getBase().getType());
638
639 mlir::Operation *indexOp = index.getDefiningOp();
640 if (indexOp && layoutWidth && width != *layoutWidth) {
641 // If the index comes from a subtraction, make sure the extension happens
642 // before it. To achieve that, look at unary minus, which already got
643 // lowered to "sub 0, x".
644 const auto sub = dyn_cast<mlir::LLVM::SubOp>(indexOp);
645 auto unary = dyn_cast_if_present<cir::UnaryOp>(
646 ptrStrideOp.getStride().getDefiningOp());
647 bool rewriteSub =
648 unary && unary.getKind() == cir::UnaryOpKind::Minus && sub;
649 if (rewriteSub)
650 index = indexOp->getOperand(idx: 1);
651
652 // Handle the cast
653 const auto llvmDstType = mlir::IntegerType::get(ctx, *layoutWidth);
654 index = getLLVMIntCast(rewriter, index, llvmDstType,
655 ptrStrideOp.getStride().getType().isUnsigned(),
656 width, *layoutWidth);
657
658 // Rewrite the sub in front of extensions/trunc
659 if (rewriteSub) {
660 index = rewriter.create<mlir::LLVM::SubOp>(
661 index.getLoc(), index.getType(),
662 rewriter.create<mlir::LLVM::ConstantOp>(index.getLoc(),
663 index.getType(), 0),
664 index);
665 rewriter.eraseOp(op: sub);
666 }
667 }
668
669 rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
670 ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index);
671 return mlir::success();
672}
673
674mlir::LogicalResult CIRToLLVMBaseClassAddrOpLowering::matchAndRewrite(
675 cir::BaseClassAddrOp baseClassOp, OpAdaptor adaptor,
676 mlir::ConversionPatternRewriter &rewriter) const {
677 const mlir::Type resultType =
678 getTypeConverter()->convertType(baseClassOp.getType());
679 mlir::Value derivedAddr = adaptor.getDerivedAddr();
680 llvm::SmallVector<mlir::LLVM::GEPArg, 1> offset = {
681 adaptor.getOffset().getZExtValue()};
682 mlir::Type byteType = mlir::IntegerType::get(resultType.getContext(), 8,
683 mlir::IntegerType::Signless);
684 if (adaptor.getOffset().getZExtValue() == 0) {
685 rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(
686 baseClassOp, resultType, adaptor.getDerivedAddr());
687 return mlir::success();
688 }
689
690 if (baseClassOp.getAssumeNotNull()) {
691 rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
692 baseClassOp, resultType, byteType, derivedAddr, offset);
693 } else {
694 auto loc = baseClassOp.getLoc();
695 mlir::Value isNull = rewriter.create<mlir::LLVM::ICmpOp>(
696 loc, mlir::LLVM::ICmpPredicate::eq, derivedAddr,
697 rewriter.create<mlir::LLVM::ZeroOp>(loc, derivedAddr.getType()));
698 mlir::Value adjusted = rewriter.create<mlir::LLVM::GEPOp>(
699 loc, resultType, byteType, derivedAddr, offset);
700 rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(baseClassOp, isNull,
701 derivedAddr, adjusted);
702 }
703 return mlir::success();
704}
705
706mlir::LogicalResult CIRToLLVMAllocaOpLowering::matchAndRewrite(
707 cir::AllocaOp op, OpAdaptor adaptor,
708 mlir::ConversionPatternRewriter &rewriter) const {
709 assert(!cir::MissingFeatures::opAllocaDynAllocSize());
710 mlir::Value size = rewriter.create<mlir::LLVM::ConstantOp>(
711 op.getLoc(), typeConverter->convertType(rewriter.getIndexType()), 1);
712 mlir::Type elementTy =
713 convertTypeForMemory(*getTypeConverter(), dataLayout, op.getAllocaType());
714 mlir::Type resultTy =
715 convertTypeForMemory(*getTypeConverter(), dataLayout, op.getType());
716
717 assert(!cir::MissingFeatures::addressSpace());
718 assert(!cir::MissingFeatures::opAllocaAnnotations());
719
720 rewriter.replaceOpWithNewOp<mlir::LLVM::AllocaOp>(
721 op, resultTy, elementTy, size, op.getAlignmentAttr().getInt());
722
723 return mlir::success();
724}
725
726mlir::LogicalResult CIRToLLVMReturnOpLowering::matchAndRewrite(
727 cir::ReturnOp op, OpAdaptor adaptor,
728 mlir::ConversionPatternRewriter &rewriter) const {
729 rewriter.replaceOpWithNewOp<mlir::LLVM::ReturnOp>(op, adaptor.getOperands());
730 return mlir::LogicalResult::success();
731}
732
733static mlir::LogicalResult
734rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
735 mlir::ConversionPatternRewriter &rewriter,
736 const mlir::TypeConverter *converter,
737 mlir::FlatSymbolRefAttr calleeAttr) {
738 llvm::SmallVector<mlir::Type, 8> llvmResults;
739 mlir::ValueTypeRange<mlir::ResultRange> cirResults = op->getResultTypes();
740
741 if (converter->convertTypes(types: cirResults, results&: llvmResults).failed())
742 return mlir::failure();
743
744 assert(!cir::MissingFeatures::opCallCallConv());
745 assert(!cir::MissingFeatures::opCallSideEffect());
746
747 mlir::LLVM::LLVMFunctionType llvmFnTy;
748 if (calleeAttr) { // direct call
749 mlir::FunctionOpInterface fn =
750 mlir::SymbolTable::lookupNearestSymbolFrom<mlir::FunctionOpInterface>(
751 op, calleeAttr);
752 assert(fn && "Did not find function for call");
753 llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>(
754 converter->convertType(fn.getFunctionType()));
755 } else { // indirect call
756 assert(!op->getOperands().empty() &&
757 "operands list must no be empty for the indirect call");
758 auto calleeTy = op->getOperands().front().getType();
759 auto calleePtrTy = cast<cir::PointerType>(calleeTy);
760 auto calleeFuncTy = cast<cir::FuncType>(calleePtrTy.getPointee());
761 calleeFuncTy.dump();
762 converter->convertType(calleeFuncTy).dump();
763 llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>(
764 converter->convertType(calleeFuncTy));
765 }
766
767 assert(!cir::MissingFeatures::opCallLandingPad());
768 assert(!cir::MissingFeatures::opCallContinueBlock());
769 assert(!cir::MissingFeatures::opCallCallConv());
770 assert(!cir::MissingFeatures::opCallSideEffect());
771
772 rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(op, llvmFnTy, calleeAttr,
773 callOperands);
774 return mlir::success();
775}
776
777mlir::LogicalResult CIRToLLVMCallOpLowering::matchAndRewrite(
778 cir::CallOp op, OpAdaptor adaptor,
779 mlir::ConversionPatternRewriter &rewriter) const {
780 return rewriteCallOrInvoke(op.getOperation(), adaptor.getOperands(), rewriter,
781 getTypeConverter(), op.getCalleeAttr());
782}
783
784mlir::LogicalResult CIRToLLVMLoadOpLowering::matchAndRewrite(
785 cir::LoadOp op, OpAdaptor adaptor,
786 mlir::ConversionPatternRewriter &rewriter) const {
787 const mlir::Type llvmTy =
788 convertTypeForMemory(*getTypeConverter(), dataLayout, op.getType());
789 assert(!cir::MissingFeatures::opLoadStoreMemOrder());
790 std::optional<size_t> opAlign = op.getAlignment();
791 unsigned alignment =
792 (unsigned)opAlign.value_or(u: dataLayout.getTypeABIAlignment(t: llvmTy));
793
794 assert(!cir::MissingFeatures::lowerModeOptLevel());
795
796 // TODO: nontemporal, syncscope.
797 assert(!cir::MissingFeatures::opLoadStoreVolatile());
798 mlir::LLVM::LoadOp newLoad = rewriter.create<mlir::LLVM::LoadOp>(
799 op->getLoc(), llvmTy, adaptor.getAddr(), alignment,
800 /*volatile=*/false, /*nontemporal=*/false,
801 /*invariant=*/false, /*invariantGroup=*/false,
802 mlir::LLVM::AtomicOrdering::not_atomic);
803
804 // Convert adapted result to its original type if needed.
805 mlir::Value result =
806 emitFromMemory(rewriter, dataLayout, op, newLoad.getResult());
807 rewriter.replaceOp(op, result);
808 assert(!cir::MissingFeatures::opLoadStoreTbaa());
809 return mlir::LogicalResult::success();
810}
811
812mlir::LogicalResult CIRToLLVMStoreOpLowering::matchAndRewrite(
813 cir::StoreOp op, OpAdaptor adaptor,
814 mlir::ConversionPatternRewriter &rewriter) const {
815 assert(!cir::MissingFeatures::opLoadStoreMemOrder());
816 const mlir::Type llvmTy =
817 getTypeConverter()->convertType(op.getValue().getType());
818 std::optional<size_t> opAlign = op.getAlignment();
819 unsigned alignment =
820 (unsigned)opAlign.value_or(u: dataLayout.getTypeABIAlignment(t: llvmTy));
821
822 assert(!cir::MissingFeatures::lowerModeOptLevel());
823
824 // Convert adapted value to its memory type if needed.
825 mlir::Value value = emitToMemory(rewriter, dataLayout,
826 op.getValue().getType(), adaptor.getValue());
827 // TODO: nontemporal, syncscope.
828 assert(!cir::MissingFeatures::opLoadStoreVolatile());
829 mlir::LLVM::StoreOp storeOp = rewriter.create<mlir::LLVM::StoreOp>(
830 op->getLoc(), value, adaptor.getAddr(), alignment, /*volatile=*/false,
831 /*nontemporal=*/false, /*invariantGroup=*/false,
832 mlir::LLVM::AtomicOrdering::not_atomic);
833 rewriter.replaceOp(op, storeOp);
834 assert(!cir::MissingFeatures::opLoadStoreTbaa());
835 return mlir::LogicalResult::success();
836}
837
838bool hasTrailingZeros(cir::ConstArrayAttr attr) {
839 auto array = mlir::dyn_cast<mlir::ArrayAttr>(attr.getElts());
840 return attr.hasTrailingZeros() ||
841 (array && std::count_if(array.begin(), array.end(), [](auto elt) {
842 auto ar = dyn_cast<cir::ConstArrayAttr>(elt);
843 return ar && hasTrailingZeros(ar);
844 }));
845}
846
847mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
848 cir::ConstantOp op, OpAdaptor adaptor,
849 mlir::ConversionPatternRewriter &rewriter) const {
850 mlir::Attribute attr = op.getValue();
851
852 if (mlir::isa<mlir::IntegerType>(op.getType())) {
853 // Verified cir.const operations cannot actually be of these types, but the
854 // lowering pass may generate temporary cir.const operations with these
855 // types. This is OK since MLIR allows unverified operations to be alive
856 // during a pass as long as they don't live past the end of the pass.
857 attr = op.getValue();
858 } else if (mlir::isa<cir::BoolType>(op.getType())) {
859 int value = mlir::cast<cir::BoolAttr>(op.getValue()).getValue();
860 attr = rewriter.getIntegerAttr(typeConverter->convertType(op.getType()),
861 value);
862 } else if (mlir::isa<cir::IntType>(op.getType())) {
863 assert(!cir::MissingFeatures::opGlobalViewAttr());
864
865 attr = rewriter.getIntegerAttr(
866 typeConverter->convertType(op.getType()),
867 mlir::cast<cir::IntAttr>(op.getValue()).getValue());
868 } else if (mlir::isa<cir::CIRFPTypeInterface>(op.getType())) {
869 attr = rewriter.getFloatAttr(
870 typeConverter->convertType(op.getType()),
871 mlir::cast<cir::FPAttr>(op.getValue()).getValue());
872 } else if (mlir::isa<cir::PointerType>(op.getType())) {
873 // Optimize with dedicated LLVM op for null pointers.
874 if (mlir::isa<cir::ConstPtrAttr>(op.getValue())) {
875 if (mlir::cast<cir::ConstPtrAttr>(op.getValue()).isNullValue()) {
876 rewriter.replaceOpWithNewOp<mlir::LLVM::ZeroOp>(
877 op, typeConverter->convertType(op.getType()));
878 return mlir::success();
879 }
880 }
881 assert(!cir::MissingFeatures::opGlobalViewAttr());
882 attr = op.getValue();
883 } else if (const auto arrTy = mlir::dyn_cast<cir::ArrayType>(op.getType())) {
884 const auto constArr = mlir::dyn_cast<cir::ConstArrayAttr>(op.getValue());
885 if (!constArr && !isa<cir::ZeroAttr, cir::UndefAttr>(op.getValue()))
886 return op.emitError() << "array does not have a constant initializer";
887
888 std::optional<mlir::Attribute> denseAttr;
889 if (constArr && hasTrailingZeros(constArr)) {
890 const mlir::Value newOp =
891 lowerCirAttrAsValue(op, constArr, rewriter, getTypeConverter());
892 rewriter.replaceOp(op, newOp);
893 return mlir::success();
894 } else if (constArr &&
895 (denseAttr = lowerConstArrayAttr(constArr, typeConverter))) {
896 attr = denseAttr.value();
897 } else {
898 const mlir::Value initVal =
899 lowerCirAttrAsValue(op, op.getValue(), rewriter, typeConverter);
900 rewriter.replaceAllUsesWith(op, initVal);
901 rewriter.eraseOp(op: op);
902 return mlir::success();
903 }
904 } else {
905 return op.emitError() << "unsupported constant type " << op.getType();
906 }
907
908 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
909 op, getTypeConverter()->convertType(op.getType()), attr);
910
911 return mlir::success();
912}
913
914/// Convert the `cir.func` attributes to `llvm.func` attributes.
915/// Only retain those attributes that are not constructed by
916/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out
917/// argument attributes.
918void CIRToLLVMFuncOpLowering::lowerFuncAttributes(
919 cir::FuncOp func, bool filterArgAndResAttrs,
920 SmallVectorImpl<mlir::NamedAttribute> &result) const {
921 assert(!cir::MissingFeatures::opFuncCallingConv());
922 for (mlir::NamedAttribute attr : func->getAttrs()) {
923 if (attr.getName() == mlir::SymbolTable::getSymbolAttrName() ||
924 attr.getName() == func.getFunctionTypeAttrName() ||
925 attr.getName() == getLinkageAttrNameString() ||
926 (filterArgAndResAttrs &&
927 (attr.getName() == func.getArgAttrsAttrName() ||
928 attr.getName() == func.getResAttrsAttrName())))
929 continue;
930
931 assert(!cir::MissingFeatures::opFuncExtraAttrs());
932 result.push_back(attr);
933 }
934}
935
936mlir::LogicalResult CIRToLLVMFuncOpLowering::matchAndRewrite(
937 cir::FuncOp op, OpAdaptor adaptor,
938 mlir::ConversionPatternRewriter &rewriter) const {
939
940 cir::FuncType fnType = op.getFunctionType();
941 assert(!cir::MissingFeatures::opFuncDsolocal());
942 bool isDsoLocal = false;
943 mlir::TypeConverter::SignatureConversion signatureConversion(
944 fnType.getNumInputs());
945
946 for (const auto &argType : llvm::enumerate(fnType.getInputs())) {
947 mlir::Type convertedType = typeConverter->convertType(argType.value());
948 if (!convertedType)
949 return mlir::failure();
950 signatureConversion.addInputs(argType.index(), convertedType);
951 }
952
953 mlir::Type resultType =
954 getTypeConverter()->convertType(fnType.getReturnType());
955
956 // Create the LLVM function operation.
957 mlir::Type llvmFnTy = mlir::LLVM::LLVMFunctionType::get(
958 resultType ? resultType : mlir::LLVM::LLVMVoidType::get(getContext()),
959 signatureConversion.getConvertedTypes(),
960 /*isVarArg=*/fnType.isVarArg());
961 // LLVMFuncOp expects a single FileLine Location instead of a fused
962 // location.
963 mlir::Location loc = op.getLoc();
964 if (mlir::FusedLoc fusedLoc = mlir::dyn_cast<mlir::FusedLoc>(loc))
965 loc = fusedLoc.getLocations()[0];
966 assert((mlir::isa<mlir::FileLineColLoc>(loc) ||
967 mlir::isa<mlir::UnknownLoc>(loc)) &&
968 "expected single location or unknown location here");
969
970 assert(!cir::MissingFeatures::opFuncLinkage());
971 mlir::LLVM::Linkage linkage = mlir::LLVM::Linkage::External;
972 assert(!cir::MissingFeatures::opFuncCallingConv());
973 mlir::LLVM::CConv cconv = mlir::LLVM::CConv::C;
974 SmallVector<mlir::NamedAttribute, 4> attributes;
975 lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes);
976
977 mlir::LLVM::LLVMFuncOp fn = rewriter.create<mlir::LLVM::LLVMFuncOp>(
978 loc, op.getName(), llvmFnTy, linkage, isDsoLocal, cconv,
979 mlir::SymbolRefAttr(), attributes);
980
981 assert(!cir::MissingFeatures::opFuncVisibility());
982
983 rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end());
984 if (failed(rewriter.convertRegionTypes(&fn.getBody(), *typeConverter,
985 &signatureConversion)))
986 return mlir::failure();
987
988 rewriter.eraseOp(op: op);
989
990 return mlir::LogicalResult::success();
991}
992
993mlir::LogicalResult CIRToLLVMGetGlobalOpLowering::matchAndRewrite(
994 cir::GetGlobalOp op, OpAdaptor adaptor,
995 mlir::ConversionPatternRewriter &rewriter) const {
996 // FIXME(cir): Premature DCE to avoid lowering stuff we're not using.
997 // CIRGen should mitigate this and not emit the get_global.
998 if (op->getUses().empty()) {
999 rewriter.eraseOp(op: op);
1000 return mlir::success();
1001 }
1002
1003 mlir::Type type = getTypeConverter()->convertType(op.getType());
1004 mlir::Operation *newop =
1005 rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), type, op.getName());
1006
1007 assert(!cir::MissingFeatures::opGlobalThreadLocal());
1008
1009 rewriter.replaceOp(op, newop);
1010 return mlir::success();
1011}
1012
1013/// Replace CIR global with a region initialized LLVM global and update
1014/// insertion point to the end of the initializer block.
1015void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
1016 cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const {
1017 const mlir::Type llvmType =
1018 convertTypeForMemory(*getTypeConverter(), dataLayout, op.getSymType());
1019
1020 // FIXME: These default values are placeholders until the the equivalent
1021 // attributes are available on cir.global ops. This duplicates code
1022 // in CIRToLLVMGlobalOpLowering::matchAndRewrite() but that will go
1023 // away when the placeholders are no longer needed.
1024 assert(!cir::MissingFeatures::opGlobalConstant());
1025 const bool isConst = false;
1026 assert(!cir::MissingFeatures::addressSpace());
1027 const unsigned addrSpace = 0;
1028 const bool isDsoLocal = op.getDsolocal();
1029 assert(!cir::MissingFeatures::opGlobalThreadLocal());
1030 const bool isThreadLocal = false;
1031 const uint64_t alignment = op.getAlignment().value_or(0);
1032 const mlir::LLVM::Linkage linkage = convertLinkage(op.getLinkage());
1033 const StringRef symbol = op.getSymName();
1034 mlir::SymbolRefAttr comdatAttr = getComdatAttr(op, rewriter);
1035
1036 SmallVector<mlir::NamedAttribute> attributes;
1037 mlir::LLVM::GlobalOp newGlobalOp =
1038 rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>(
1039 op, llvmType, isConst, linkage, symbol, nullptr, alignment, addrSpace,
1040 isDsoLocal, isThreadLocal, comdatAttr, attributes);
1041 newGlobalOp.getRegion().emplaceBlock();
1042 rewriter.setInsertionPointToEnd(newGlobalOp.getInitializerBlock());
1043}
1044
1045mlir::LogicalResult
1046CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
1047 cir::GlobalOp op, mlir::Attribute init,
1048 mlir::ConversionPatternRewriter &rewriter) const {
1049 // TODO: Generalize this handling when more types are needed here.
1050 assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
1051 cir::ConstComplexAttr, cir::ZeroAttr>(init)));
1052
1053 // TODO(cir): once LLVM's dialect has proper equivalent attributes this
1054 // should be updated. For now, we use a custom op to initialize globals
1055 // to the appropriate value.
1056 const mlir::Location loc = op.getLoc();
1057 setupRegionInitializedLLVMGlobalOp(op, rewriter);
1058 CIRAttrToValue valueConverter(op, rewriter, typeConverter);
1059 mlir::Value value = valueConverter.visit(attr: init);
1060 rewriter.create<mlir::LLVM::ReturnOp>(loc, value);
1061 return mlir::success();
1062}
1063
1064mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
1065 cir::GlobalOp op, OpAdaptor adaptor,
1066 mlir::ConversionPatternRewriter &rewriter) const {
1067
1068 std::optional<mlir::Attribute> init = op.getInitialValue();
1069
1070 // Fetch required values to create LLVM op.
1071 const mlir::Type cirSymType = op.getSymType();
1072
1073 // This is the LLVM dialect type.
1074 const mlir::Type llvmType =
1075 convertTypeForMemory(*getTypeConverter(), dataLayout, cirSymType);
1076 // FIXME: These default values are placeholders until the the equivalent
1077 // attributes are available on cir.global ops.
1078 assert(!cir::MissingFeatures::opGlobalConstant());
1079 const bool isConst = false;
1080 assert(!cir::MissingFeatures::addressSpace());
1081 const unsigned addrSpace = 0;
1082 const bool isDsoLocal = op.getDsolocal();
1083 assert(!cir::MissingFeatures::opGlobalThreadLocal());
1084 const bool isThreadLocal = false;
1085 const uint64_t alignment = op.getAlignment().value_or(0);
1086 const mlir::LLVM::Linkage linkage = convertLinkage(op.getLinkage());
1087 const StringRef symbol = op.getSymName();
1088 SmallVector<mlir::NamedAttribute> attributes;
1089 mlir::SymbolRefAttr comdatAttr = getComdatAttr(op, rewriter);
1090
1091 if (init.has_value()) {
1092 if (mlir::isa<cir::FPAttr, cir::IntAttr, cir::BoolAttr>(init.value())) {
1093 GlobalInitAttrRewriter initRewriter(llvmType, rewriter);
1094 init = initRewriter.visit(attr: init.value());
1095 // If initRewriter returned a null attribute, init will have a value but
1096 // the value will be null. If that happens, initRewriter didn't handle the
1097 // attribute type. It probably needs to be added to
1098 // GlobalInitAttrRewriter.
1099 if (!init.value()) {
1100 op.emitError() << "unsupported initializer '" << init.value() << "'";
1101 return mlir::failure();
1102 }
1103 } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
1104 cir::ConstPtrAttr, cir::ConstComplexAttr,
1105 cir::ZeroAttr>(init.value())) {
1106 // TODO(cir): once LLVM's dialect has proper equivalent attributes this
1107 // should be updated. For now, we use a custom op to initialize globals
1108 // to the appropriate value.
1109 return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter);
1110 } else {
1111 // We will only get here if new initializer types are added and this
1112 // code is not updated to handle them.
1113 op.emitError() << "unsupported initializer '" << init.value() << "'";
1114 return mlir::failure();
1115 }
1116 }
1117
1118 // Rewrite op.
1119 rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>(
1120 op, llvmType, isConst, linkage, symbol, init.value_or(mlir::Attribute()),
1121 alignment, addrSpace, isDsoLocal, isThreadLocal, comdatAttr, attributes);
1122 return mlir::success();
1123}
1124
1125mlir::SymbolRefAttr
1126CIRToLLVMGlobalOpLowering::getComdatAttr(cir::GlobalOp &op,
1127 mlir::OpBuilder &builder) const {
1128 if (!op.getComdat())
1129 return mlir::SymbolRefAttr{};
1130
1131 mlir::ModuleOp module = op->getParentOfType<mlir::ModuleOp>();
1132 mlir::OpBuilder::InsertionGuard guard(builder);
1133 StringRef comdatName("__llvm_comdat_globals");
1134 if (!comdatOp) {
1135 builder.setInsertionPointToStart(module.getBody());
1136 comdatOp =
1137 builder.create<mlir::LLVM::ComdatOp>(module.getLoc(), comdatName);
1138 }
1139
1140 builder.setInsertionPointToStart(&comdatOp.getBody().back());
1141 auto selectorOp = builder.create<mlir::LLVM::ComdatSelectorOp>(
1142 comdatOp.getLoc(), op.getSymName(), mlir::LLVM::comdat::Comdat::Any);
1143 return mlir::SymbolRefAttr::get(
1144 builder.getContext(), comdatName,
1145 mlir::FlatSymbolRefAttr::get(selectorOp.getSymNameAttr()));
1146}
1147
1148mlir::LogicalResult CIRToLLVMSwitchFlatOpLowering::matchAndRewrite(
1149 cir::SwitchFlatOp op, OpAdaptor adaptor,
1150 mlir::ConversionPatternRewriter &rewriter) const {
1151
1152 llvm::SmallVector<mlir::APInt, 8> caseValues;
1153 for (mlir::Attribute val : op.getCaseValues()) {
1154 auto intAttr = cast<cir::IntAttr>(val);
1155 caseValues.push_back(intAttr.getValue());
1156 }
1157
1158 llvm::SmallVector<mlir::Block *, 8> caseDestinations;
1159 llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
1160
1161 for (mlir::Block *x : op.getCaseDestinations())
1162 caseDestinations.push_back(x);
1163
1164 for (mlir::OperandRange x : op.getCaseOperands())
1165 caseOperands.push_back(x);
1166
1167 // Set switch op to branch to the newly created blocks.
1168 rewriter.setInsertionPoint(op);
1169 rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
1170 op, adaptor.getCondition(), op.getDefaultDestination(),
1171 op.getDefaultOperands(), caseValues, caseDestinations, caseOperands);
1172 return mlir::success();
1173}
1174
1175mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
1176 cir::UnaryOp op, OpAdaptor adaptor,
1177 mlir::ConversionPatternRewriter &rewriter) const {
1178 assert(op.getType() == op.getInput().getType() &&
1179 "Unary operation's operand type and result type are different");
1180 mlir::Type type = op.getType();
1181 mlir::Type elementType = elementTypeIfVector(type);
1182 bool isVector = mlir::isa<cir::VectorType>(type);
1183 mlir::Type llvmType = getTypeConverter()->convertType(type);
1184 mlir::Location loc = op.getLoc();
1185
1186 // Integer unary operations: + - ~ ++ --
1187 if (mlir::isa<cir::IntType>(elementType)) {
1188 mlir::LLVM::IntegerOverflowFlags maybeNSW =
1189 op.getNoSignedWrap() ? mlir::LLVM::IntegerOverflowFlags::nsw
1190 : mlir::LLVM::IntegerOverflowFlags::none;
1191 switch (op.getKind()) {
1192 case cir::UnaryOpKind::Inc: {
1193 assert(!isVector && "++ not allowed on vector types");
1194 auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
1195 rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(
1196 op, llvmType, adaptor.getInput(), one, maybeNSW);
1197 return mlir::success();
1198 }
1199 case cir::UnaryOpKind::Dec: {
1200 assert(!isVector && "-- not allowed on vector types");
1201 auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
1202 rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, adaptor.getInput(),
1203 one, maybeNSW);
1204 return mlir::success();
1205 }
1206 case cir::UnaryOpKind::Plus:
1207 rewriter.replaceOp(op, adaptor.getInput());
1208 return mlir::success();
1209 case cir::UnaryOpKind::Minus: {
1210 mlir::Value zero;
1211 if (isVector)
1212 zero = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmType);
1213 else
1214 zero = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 0);
1215 rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(
1216 op, zero, adaptor.getInput(), maybeNSW);
1217 return mlir::success();
1218 }
1219 case cir::UnaryOpKind::Not: {
1220 // bit-wise compliment operator, implemented as an XOR with -1.
1221 mlir::Value minusOne;
1222 if (isVector) {
1223 const uint64_t numElements =
1224 mlir::dyn_cast<cir::VectorType>(type).getSize();
1225 std::vector<int32_t> values(numElements, -1);
1226 mlir::DenseIntElementsAttr denseVec = rewriter.getI32VectorAttr(values);
1227 minusOne =
1228 rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, denseVec);
1229 } else {
1230 minusOne = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, -1);
1231 }
1232 rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(),
1233 minusOne);
1234 return mlir::success();
1235 }
1236 }
1237 llvm_unreachable("Unexpected unary op for int");
1238 }
1239
1240 // Floating point unary operations: + - ++ --
1241 if (mlir::isa<cir::CIRFPTypeInterface>(elementType)) {
1242 switch (op.getKind()) {
1243 case cir::UnaryOpKind::Inc: {
1244 assert(!isVector && "++ not allowed on vector types");
1245 mlir::LLVM::ConstantOp one = rewriter.create<mlir::LLVM::ConstantOp>(
1246 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
1247 rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmType, one,
1248 adaptor.getInput());
1249 return mlir::success();
1250 }
1251 case cir::UnaryOpKind::Dec: {
1252 assert(!isVector && "-- not allowed on vector types");
1253 mlir::LLVM::ConstantOp minusOne = rewriter.create<mlir::LLVM::ConstantOp>(
1254 loc, llvmType, rewriter.getFloatAttr(llvmType, -1.0));
1255 rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmType, minusOne,
1256 adaptor.getInput());
1257 return mlir::success();
1258 }
1259 case cir::UnaryOpKind::Plus:
1260 rewriter.replaceOp(op, adaptor.getInput());
1261 return mlir::success();
1262 case cir::UnaryOpKind::Minus:
1263 rewriter.replaceOpWithNewOp<mlir::LLVM::FNegOp>(op, llvmType,
1264 adaptor.getInput());
1265 return mlir::success();
1266 case cir::UnaryOpKind::Not:
1267 return op.emitError() << "Unary not is invalid for floating-point types";
1268 }
1269 llvm_unreachable("Unexpected unary op for float");
1270 }
1271
1272 // Boolean unary operations: ! only. (For all others, the operand has
1273 // already been promoted to int.)
1274 if (mlir::isa<cir::BoolType>(elementType)) {
1275 switch (op.getKind()) {
1276 case cir::UnaryOpKind::Inc:
1277 case cir::UnaryOpKind::Dec:
1278 case cir::UnaryOpKind::Plus:
1279 case cir::UnaryOpKind::Minus:
1280 // Some of these are allowed in source code, but we shouldn't get here
1281 // with a boolean type.
1282 return op.emitError() << "Unsupported unary operation on boolean type";
1283 case cir::UnaryOpKind::Not: {
1284 assert(!isVector && "NYI: op! on vector mask");
1285 auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
1286 rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(),
1287 one);
1288 return mlir::success();
1289 }
1290 }
1291 llvm_unreachable("Unexpected unary op for bool");
1292 }
1293
1294 // Pointer unary operations: + only. (++ and -- of pointers are implemented
1295 // with cir.ptr_stride, not cir.unary.)
1296 if (mlir::isa<cir::PointerType>(elementType)) {
1297 return op.emitError()
1298 << "Unary operation on pointer types is not yet implemented";
1299 }
1300
1301 return op.emitError() << "Unary operation has unsupported type: "
1302 << elementType;
1303}
1304
1305mlir::LLVM::IntegerOverflowFlags
1306CIRToLLVMBinOpLowering::getIntOverflowFlag(cir::BinOp op) const {
1307 if (op.getNoUnsignedWrap())
1308 return mlir::LLVM::IntegerOverflowFlags::nuw;
1309
1310 if (op.getNoSignedWrap())
1311 return mlir::LLVM::IntegerOverflowFlags::nsw;
1312
1313 return mlir::LLVM::IntegerOverflowFlags::none;
1314}
1315
1316static bool isIntTypeUnsigned(mlir::Type type) {
1317 // TODO: Ideally, we should only need to check cir::IntType here.
1318 return mlir::isa<cir::IntType>(type)
1319 ? mlir::cast<cir::IntType>(type).isUnsigned()
1320 : mlir::cast<mlir::IntegerType>(type).isUnsigned();
1321}
1322
1323mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
1324 cir::BinOp op, OpAdaptor adaptor,
1325 mlir::ConversionPatternRewriter &rewriter) const {
1326 if (adaptor.getLhs().getType() != adaptor.getRhs().getType())
1327 return op.emitError() << "inconsistent operands' types not supported yet";
1328
1329 mlir::Type type = op.getRhs().getType();
1330 if (!mlir::isa<cir::IntType, cir::BoolType, cir::CIRFPTypeInterface,
1331 mlir::IntegerType, cir::VectorType>(type))
1332 return op.emitError() << "operand type not supported yet";
1333
1334 const mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
1335 const mlir::Type llvmEltTy = elementTypeIfVector(type: llvmTy);
1336
1337 const mlir::Value rhs = adaptor.getRhs();
1338 const mlir::Value lhs = adaptor.getLhs();
1339 type = elementTypeIfVector(type);
1340
1341 switch (op.getKind()) {
1342 case cir::BinOpKind::Add:
1343 if (mlir::isa<mlir::IntegerType>(Val: llvmEltTy)) {
1344 if (op.getSaturated()) {
1345 if (isIntTypeUnsigned(type)) {
1346 rewriter.replaceOpWithNewOp<mlir::LLVM::UAddSat>(op, lhs, rhs);
1347 break;
1348 }
1349 rewriter.replaceOpWithNewOp<mlir::LLVM::SAddSat>(op, lhs, rhs);
1350 break;
1351 }
1352 rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(op, llvmTy, lhs, rhs,
1353 getIntOverflowFlag(op));
1354 } else {
1355 rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, lhs, rhs);
1356 }
1357 break;
1358 case cir::BinOpKind::Sub:
1359 if (mlir::isa<mlir::IntegerType>(Val: llvmEltTy)) {
1360 if (op.getSaturated()) {
1361 if (isIntTypeUnsigned(type)) {
1362 rewriter.replaceOpWithNewOp<mlir::LLVM::USubSat>(op, lhs, rhs);
1363 break;
1364 }
1365 rewriter.replaceOpWithNewOp<mlir::LLVM::SSubSat>(op, lhs, rhs);
1366 break;
1367 }
1368 rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, llvmTy, lhs, rhs,
1369 getIntOverflowFlag(op));
1370 } else {
1371 rewriter.replaceOpWithNewOp<mlir::LLVM::FSubOp>(op, lhs, rhs);
1372 }
1373 break;
1374 case cir::BinOpKind::Mul:
1375 if (mlir::isa<mlir::IntegerType>(llvmEltTy))
1376 rewriter.replaceOpWithNewOp<mlir::LLVM::MulOp>(op, llvmTy, lhs, rhs,
1377 getIntOverflowFlag(op));
1378 else
1379 rewriter.replaceOpWithNewOp<mlir::LLVM::FMulOp>(op, lhs, rhs);
1380 break;
1381 case cir::BinOpKind::Div:
1382 if (mlir::isa<mlir::IntegerType>(Val: llvmEltTy)) {
1383 auto isUnsigned = isIntTypeUnsigned(type);
1384 if (isUnsigned)
1385 rewriter.replaceOpWithNewOp<mlir::LLVM::UDivOp>(op, lhs, rhs);
1386 else
1387 rewriter.replaceOpWithNewOp<mlir::LLVM::SDivOp>(op, lhs, rhs);
1388 } else {
1389 rewriter.replaceOpWithNewOp<mlir::LLVM::FDivOp>(op, lhs, rhs);
1390 }
1391 break;
1392 case cir::BinOpKind::Rem:
1393 if (mlir::isa<mlir::IntegerType>(Val: llvmEltTy)) {
1394 auto isUnsigned = isIntTypeUnsigned(type);
1395 if (isUnsigned)
1396 rewriter.replaceOpWithNewOp<mlir::LLVM::URemOp>(op, lhs, rhs);
1397 else
1398 rewriter.replaceOpWithNewOp<mlir::LLVM::SRemOp>(op, lhs, rhs);
1399 } else {
1400 rewriter.replaceOpWithNewOp<mlir::LLVM::FRemOp>(op, lhs, rhs);
1401 }
1402 break;
1403 case cir::BinOpKind::And:
1404 rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, lhs, rhs);
1405 break;
1406 case cir::BinOpKind::Or:
1407 rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, lhs, rhs);
1408 break;
1409 case cir::BinOpKind::Xor:
1410 rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, lhs, rhs);
1411 break;
1412 case cir::BinOpKind::Max:
1413 if (mlir::isa<mlir::IntegerType>(Val: llvmEltTy)) {
1414 auto isUnsigned = isIntTypeUnsigned(type);
1415 if (isUnsigned)
1416 rewriter.replaceOpWithNewOp<mlir::LLVM::UMaxOp>(op, llvmTy, lhs, rhs);
1417 else
1418 rewriter.replaceOpWithNewOp<mlir::LLVM::SMaxOp>(op, llvmTy, lhs, rhs);
1419 }
1420 break;
1421 }
1422 return mlir::LogicalResult::success();
1423}
1424
1425/// Convert from a CIR comparison kind to an LLVM IR integral comparison kind.
1426static mlir::LLVM::ICmpPredicate
1427convertCmpKindToICmpPredicate(cir::CmpOpKind kind, bool isSigned) {
1428 using CIR = cir::CmpOpKind;
1429 using LLVMICmp = mlir::LLVM::ICmpPredicate;
1430 switch (kind) {
1431 case CIR::eq:
1432 return LLVMICmp::eq;
1433 case CIR::ne:
1434 return LLVMICmp::ne;
1435 case CIR::lt:
1436 return (isSigned ? LLVMICmp::slt : LLVMICmp::ult);
1437 case CIR::le:
1438 return (isSigned ? LLVMICmp::sle : LLVMICmp::ule);
1439 case CIR::gt:
1440 return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt);
1441 case CIR::ge:
1442 return (isSigned ? LLVMICmp::sge : LLVMICmp::uge);
1443 }
1444 llvm_unreachable("Unknown CmpOpKind");
1445}
1446
1447/// Convert from a CIR comparison kind to an LLVM IR floating-point comparison
1448/// kind.
1449static mlir::LLVM::FCmpPredicate
1450convertCmpKindToFCmpPredicate(cir::CmpOpKind kind) {
1451 using CIR = cir::CmpOpKind;
1452 using LLVMFCmp = mlir::LLVM::FCmpPredicate;
1453 switch (kind) {
1454 case CIR::eq:
1455 return LLVMFCmp::oeq;
1456 case CIR::ne:
1457 return LLVMFCmp::une;
1458 case CIR::lt:
1459 return LLVMFCmp::olt;
1460 case CIR::le:
1461 return LLVMFCmp::ole;
1462 case CIR::gt:
1463 return LLVMFCmp::ogt;
1464 case CIR::ge:
1465 return LLVMFCmp::oge;
1466 }
1467 llvm_unreachable("Unknown CmpOpKind");
1468}
1469
1470mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
1471 cir::CmpOp cmpOp, OpAdaptor adaptor,
1472 mlir::ConversionPatternRewriter &rewriter) const {
1473 mlir::Type type = cmpOp.getLhs().getType();
1474
1475 assert(!cir::MissingFeatures::dataMemberType());
1476 assert(!cir::MissingFeatures::methodType());
1477
1478 // Lower to LLVM comparison op.
1479 if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
1480 bool isSigned = mlir::isa<cir::IntType>(type)
1481 ? mlir::cast<cir::IntType>(type).isSigned()
1482 : mlir::cast<mlir::IntegerType>(type).isSigned();
1483 mlir::LLVM::ICmpPredicate kind =
1484 convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned);
1485 rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
1486 cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
1487 } else if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(type)) {
1488 mlir::LLVM::ICmpPredicate kind =
1489 convertCmpKindToICmpPredicate(cmpOp.getKind(),
1490 /* isSigned=*/false);
1491 rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
1492 cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
1493 } else if (mlir::isa<cir::CIRFPTypeInterface>(type)) {
1494 mlir::LLVM::FCmpPredicate kind =
1495 convertCmpKindToFCmpPredicate(cmpOp.getKind());
1496 rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>(
1497 cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
1498 } else {
1499 return cmpOp.emitError() << "unsupported type for CmpOp: " << type;
1500 }
1501
1502 return mlir::success();
1503}
1504
1505mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
1506 cir::ShiftOp op, OpAdaptor adaptor,
1507 mlir::ConversionPatternRewriter &rewriter) const {
1508 assert((op.getValue().getType() == op.getType()) &&
1509 "inconsistent operands' types NYI");
1510
1511 const mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
1512 mlir::Value amt = adaptor.getAmount();
1513 mlir::Value val = adaptor.getValue();
1514
1515 auto cirAmtTy = mlir::dyn_cast<cir::IntType>(op.getAmount().getType());
1516 bool isUnsigned;
1517 if (cirAmtTy) {
1518 auto cirValTy = mlir::cast<cir::IntType>(op.getValue().getType());
1519 isUnsigned = cirValTy.isUnsigned();
1520
1521 // Ensure shift amount is the same type as the value. Some undefined
1522 // behavior might occur in the casts below as per [C99 6.5.7.3].
1523 // Vector type shift amount needs no cast as type consistency is expected to
1524 // be already be enforced at CIRGen.
1525 if (cirAmtTy)
1526 amt = getLLVMIntCast(rewriter, amt, llvmTy, true, cirAmtTy.getWidth(),
1527 cirValTy.getWidth());
1528 } else {
1529 auto cirValVTy = mlir::cast<cir::VectorType>(op.getValue().getType());
1530 isUnsigned =
1531 mlir::cast<cir::IntType>(cirValVTy.getElementType()).isUnsigned();
1532 }
1533
1534 // Lower to the proper LLVM shift operation.
1535 if (op.getIsShiftleft()) {
1536 rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt);
1537 return mlir::success();
1538 }
1539
1540 if (isUnsigned)
1541 rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
1542 else
1543 rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt);
1544 return mlir::success();
1545}
1546
1547mlir::LogicalResult CIRToLLVMSelectOpLowering::matchAndRewrite(
1548 cir::SelectOp op, OpAdaptor adaptor,
1549 mlir::ConversionPatternRewriter &rewriter) const {
1550 auto getConstantBool = [](mlir::Value value) -> cir::BoolAttr {
1551 auto definingOp =
1552 mlir::dyn_cast_if_present<cir::ConstantOp>(value.getDefiningOp());
1553 if (!definingOp)
1554 return {};
1555
1556 auto constValue = mlir::dyn_cast<cir::BoolAttr>(definingOp.getValue());
1557 if (!constValue)
1558 return {};
1559
1560 return constValue;
1561 };
1562
1563 // Two special cases in the LLVMIR codegen of select op:
1564 // - select %0, %1, false => and %0, %1
1565 // - select %0, true, %1 => or %0, %1
1566 if (mlir::isa<cir::BoolType>(op.getTrueValue().getType())) {
1567 cir::BoolAttr trueValue = getConstantBool(op.getTrueValue());
1568 cir::BoolAttr falseValue = getConstantBool(op.getFalseValue());
1569 if (falseValue && !falseValue.getValue()) {
1570 // select %0, %1, false => and %0, %1
1571 rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, adaptor.getCondition(),
1572 adaptor.getTrueValue());
1573 return mlir::success();
1574 }
1575 if (trueValue && trueValue.getValue()) {
1576 // select %0, true, %1 => or %0, %1
1577 rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, adaptor.getCondition(),
1578 adaptor.getFalseValue());
1579 return mlir::success();
1580 }
1581 }
1582
1583 mlir::Value llvmCondition = adaptor.getCondition();
1584 rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
1585 op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue());
1586
1587 return mlir::success();
1588}
1589
/// Register the CIR -> LLVM dialect type conversions on \p converter.
///
/// Every conversion lambda captures by reference ([&]). \p dataLayout is used
/// by the array and record conversions (memory-type conversion and choosing a
/// union's largest member). Both \p converter and \p dataLayout must therefore
/// outlive every use of the converter.
static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
                                 mlir::DataLayout &dataLayout) {
  converter.addConversion(callback: [&](cir::PointerType type) -> mlir::Type {
    // Drop pointee type since LLVM dialect only allows opaque pointers.
    // Address spaces are not modeled yet, so every pointer lands in AS 0.
    assert(!cir::MissingFeatures::addressSpace());
    unsigned targetAS = 0;

    return mlir::LLVM::LLVMPointerType::get(type.getContext(), targetAS);
  });
  // Arrays: the element type is converted with its in-memory representation
  // (convertTypeForMemory), not with the plain converter.
  converter.addConversion(callback: [&](cir::ArrayType type) -> mlir::Type {
    mlir::Type ty =
        convertTypeForMemory(converter, dataLayout, type.getElementType());
    return mlir::LLVM::LLVMArrayType::get(ty, type.getSize());
  });
  // Vectors become builtin MLIR vectors of the converted element type.
  converter.addConversion(callback: [&](cir::VectorType type) -> mlir::Type {
    const mlir::Type ty = converter.convertType(type.getElementType());
    return mlir::VectorType::get(type.getSize(), ty);
  });
  // cir.bool is a signless i1.
  converter.addConversion(callback: [&](cir::BoolType type) -> mlir::Type {
    return mlir::IntegerType::get(type.getContext(), 1,
                                  mlir::IntegerType::Signless);
  });
  converter.addConversion(callback: [&](cir::IntType type) -> mlir::Type {
    // LLVM doesn't work with signed types, so we drop the CIR signs here.
    return mlir::IntegerType::get(type.getContext(), type.getWidth());
  });
  // Floating-point types map 1:1 onto the builtin MLIR float types.
  converter.addConversion(callback: [&](cir::SingleType type) -> mlir::Type {
    return mlir::Float32Type::get(type.getContext());
  });
  converter.addConversion(callback: [&](cir::DoubleType type) -> mlir::Type {
    return mlir::Float64Type::get(type.getContext());
  });
  converter.addConversion(callback: [&](cir::FP80Type type) -> mlir::Type {
    return mlir::Float80Type::get(type.getContext());
  });
  converter.addConversion(callback: [&](cir::FP128Type type) -> mlir::Type {
    return mlir::Float128Type::get(type.getContext());
  });
  // long double has no fixed format of its own here; lower whatever
  // underlying type the CIR type records.
  converter.addConversion(callback: [&](cir::LongDoubleType type) -> mlir::Type {
    return converter.convertType(type.getUnderlying());
  });
  converter.addConversion(callback: [&](cir::FP16Type type) -> mlir::Type {
    return mlir::Float16Type::get(type.getContext());
  });
  converter.addConversion(callback: [&](cir::BF16Type type) -> mlir::Type {
    return mlir::BFloat16Type::get(type.getContext());
  });
  converter.addConversion(callback: [&](cir::ComplexType type) -> mlir::Type {
    // A complex type is lowered to an LLVM struct that contains the real and
    // imaginary part as data fields.
    mlir::Type elementTy = converter.convertType(type.getElementType());
    mlir::Type structFields[2] = {elementTy, elementTy};
    return mlir::LLVM::LLVMStructType::getLiteral(type.getContext(),
                                                  structFields);
  });
  // Function types: returns std::nullopt (conversion failure) if any input
  // type cannot be converted. The return type's conversion result is not
  // checked here.
  converter.addConversion(callback: [&](cir::FuncType type) -> std::optional<mlir::Type> {
    auto result = converter.convertType(type.getReturnType());
    llvm::SmallVector<mlir::Type> arguments;
    arguments.reserve(N: type.getNumInputs());
    if (converter.convertTypes(types: type.getInputs(), results&: arguments).failed())
      return std::nullopt;
    auto varArg = type.isVarArg();
    return mlir::LLVM::LLVMFunctionType::get(result, arguments, varArg);
  });
  converter.addConversion(callback: [&](cir::RecordType type) -> mlir::Type {
    // Convert struct members.
    llvm::SmallVector<mlir::Type> llvmMembers;
    switch (type.getKind()) {
    case cir::RecordType::Class:
    case cir::RecordType::Struct:
      // Every member is kept, in order, in its in-memory representation.
      for (mlir::Type ty : type.getMembers())
        llvmMembers.push_back(convertTypeForMemory(converter, dataLayout, ty));
      break;
    // Unions are lowered as only the largest member.
    case cir::RecordType::Union:
      if (auto largestMember = type.getLargestMember(dataLayout))
        llvmMembers.push_back(
            Elt: convertTypeForMemory(converter, dataLayout, largestMember));
      // A padded union additionally keeps its last member (the padding)
      // after the largest one.
      if (type.getPadded()) {
        auto last = *type.getMembers().rbegin();
        llvmMembers.push_back(
            Elt: convertTypeForMemory(converter, dataLayout, last));
      }
      break;
    }

    // Record has a name: lower as an identified record.
    mlir::LLVM::LLVMStructType llvmStruct;
    if (type.getName()) {
      llvmStruct = mlir::LLVM::LLVMStructType::getIdentified(
          type.getContext(), type.getPrefixedName());
      if (llvmStruct.setBody(llvmMembers, type.getPacked()).failed())
        llvm_unreachable("Failed to set body of record");
    } else { // Record has no name: lower as literal record.
      llvmStruct = mlir::LLVM::LLVMStructType::getLiteral(
          type.getContext(), llvmMembers, type.getPacked());
    }

    return llvmStruct;
  });
}
1691
// The applyPartialConversion function traverses blocks in dominance order,
// so it does not lower any operations that are not reachable from the
// operations passed in as arguments. Since we do need to lower such code in
// order to avoid verification errors, we cannot just pass the module op to
// applyPartialConversion. We must build a set of unreachable ops and
// explicitly add them, along with the module, to the vector we pass to
// applyPartialConversion.
1699//
1700// For instance, this CIR code:
1701//
1702// cir.func @foo(%arg0: !s32i) -> !s32i {
1703// %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool
1704// cir.if %4 {
1705// %5 = cir.const #cir.int<1> : !s32i
1706// cir.return %5 : !s32i
1707// } else {
1708// %5 = cir.const #cir.int<0> : !s32i
1709// cir.return %5 : !s32i
1710// }
1711// cir.return %arg0 : !s32i
1712// }
1713//
// contains an unreachable return operation (the last one). After the flattening
// pass it is placed into an unreachable block. The possible error after the
// lowering pass is: error: 'cir.return' op expects parent op to be one of
// 'cir.func, cir.scope, cir.if ... The reason is that this operation was not
// lowered, while its new parent is llvm.func.
1719//
1720// In the future we may want to get rid of this function and use a DCE pass or
1721// something similar. But for now we need to guarantee the absence of the
1722// dialect verification errors.
1723static void collectUnreachable(mlir::Operation *parent,
1724 llvm::SmallVector<mlir::Operation *> &ops) {
1725
1726 llvm::SmallVector<mlir::Block *> unreachableBlocks;
1727 parent->walk(callback: [&](mlir::Block *blk) { // check
1728 if (blk->hasNoPredecessors() && !blk->isEntryBlock())
1729 unreachableBlocks.push_back(Elt: blk);
1730 });
1731
1732 std::set<mlir::Block *> visited;
1733 for (mlir::Block *root : unreachableBlocks) {
1734 // We create a work list for each unreachable block.
1735 // Thus we traverse operations in some order.
1736 std::deque<mlir::Block *> workList;
1737 workList.push_back(x: root);
1738
1739 while (!workList.empty()) {
1740 mlir::Block *blk = workList.back();
1741 workList.pop_back();
1742 if (visited.count(x: blk))
1743 continue;
1744 visited.emplace(args&: blk);
1745
1746 for (mlir::Operation &op : *blk)
1747 ops.push_back(Elt: &op);
1748
1749 for (mlir::Block *succ : blk->getSuccessors())
1750 workList.push_back(x: succ);
1751 }
1752 }
1753}
1754
1755void ConvertCIRToLLVMPass::processCIRAttrs(mlir::ModuleOp module) {
1756 // Lower the module attributes to LLVM equivalents.
1757 if (mlir::Attribute tripleAttr =
1758 module->getAttr(cir::CIRDialect::getTripleAttrName()))
1759 module->setAttr(mlir::LLVM::LLVMDialect::getTargetTripleAttrName(),
1760 tripleAttr);
1761}
1762
/// Lower the module from CIR (together with any builtin/func ops) to the
/// LLVM dialect.
///
/// Types are mapped by prepareTypeConverter using the module's data layout.
/// Ops are rewritten by the CIRToLLVM*Lowering patterns registered below.
/// The conversion is partial. If applyPartialConversion fails (for example,
/// because an op the target marks illegal survives), the pass signals
/// failure.
void ConvertCIRToLLVMPass::runOnOperation() {
  llvm::TimeTraceScope scope("Convert CIR to LLVM Pass");

  mlir::ModuleOp module = getOperation();
  // `dl` and `converter` are captured by reference by the type-conversion
  // lambdas and by the patterns, so they must stay alive (as locals here)
  // until applyPartialConversion returns.
  mlir::DataLayout dl(module);
  mlir::LLVMTypeConverter converter(&getContext());
  prepareTypeConverter(converter, dataLayout&: dl);

  mlir::RewritePatternSet patterns(&getContext());

  // The return lowering is registered without the type converter.
  patterns.add<CIRToLLVMReturnOpLowering>(arg: patterns.getContext());
  // This could currently be merged with the group below, but it will get more
  // arguments later, so we'll keep it separate for now.
  // These patterns additionally receive the data layout.
  patterns.add<CIRToLLVMAllocaOpLowering>(arg&: converter, args: patterns.getContext(), args&: dl);
  patterns.add<CIRToLLVMLoadOpLowering>(arg&: converter, args: patterns.getContext(), args&: dl);
  patterns.add<CIRToLLVMStoreOpLowering>(arg&: converter, args: patterns.getContext(), args&: dl);
  patterns.add<CIRToLLVMGlobalOpLowering>(arg&: converter, args: patterns.getContext(), args&: dl);
  patterns.add<CIRToLLVMCastOpLowering>(arg&: converter, args: patterns.getContext(), args&: dl);
  patterns.add<CIRToLLVMPtrStrideOpLowering>(arg&: converter, args: patterns.getContext(),
                                             args&: dl);
  // Patterns that only need the type converter.
  patterns.add<
      // clang-format off
      CIRToLLVMBaseClassAddrOpLowering,
      CIRToLLVMBinOpLowering,
      CIRToLLVMBrCondOpLowering,
      CIRToLLVMBrOpLowering,
      CIRToLLVMCallOpLowering,
      CIRToLLVMCmpOpLowering,
      CIRToLLVMConstantOpLowering,
      CIRToLLVMFuncOpLowering,
      CIRToLLVMGetGlobalOpLowering,
      CIRToLLVMGetMemberOpLowering,
      CIRToLLVMSelectOpLowering,
      CIRToLLVMSwitchFlatOpLowering,
      CIRToLLVMShiftOpLowering,
      CIRToLLVMStackSaveOpLowering,
      CIRToLLVMStackRestoreOpLowering,
      CIRToLLVMTrapOpLowering,
      CIRToLLVMUnaryOpLowering,
      CIRToLLVMVecCreateOpLowering,
      CIRToLLVMVecExtractOpLowering,
      CIRToLLVMVecInsertOpLowering,
      CIRToLLVMVecCmpOpLowering,
      CIRToLLVMVecSplatOpLowering,
      CIRToLLVMVecShuffleOpLowering,
      CIRToLLVMVecShuffleDynamicOpLowering,
      CIRToLLVMVecTernaryOpLowering
      // clang-format on
      >(arg&: converter, args: patterns.getContext());

  // Module-level attributes (e.g. the target triple) are translated directly,
  // not through the rewrite patterns.
  processCIRAttrs(module: module);

  // Only builtin.module and LLVM dialect ops are legal. CIR, func, and the
  // remaining builtin ops must all be converted.
  mlir::ConversionTarget target(getContext());
  target.addLegalOp<mlir::ModuleOp>();
  target.addLegalDialect<mlir::LLVM::LLVMDialect>();
  target.addIllegalDialect<mlir::BuiltinDialect, cir::CIRDialect,
                           mlir::func::FuncDialect>();

  // Convert the module plus, explicitly, the ops sitting in unreachable
  // blocks, which the dominance-ordered traversal would otherwise skip.
  llvm::SmallVector<mlir::Operation *> ops;
  ops.push_back(Elt: module);
  collectUnreachable(module, ops);

  if (failed(applyPartialConversion(ops, target, std::move(patterns))))
    signalPassFailure();
}
1828
1829mlir::LogicalResult CIRToLLVMBrOpLowering::matchAndRewrite(
1830 cir::BrOp op, OpAdaptor adaptor,
1831 mlir::ConversionPatternRewriter &rewriter) const {
1832 rewriter.replaceOpWithNewOp<mlir::LLVM::BrOp>(op, adaptor.getOperands(),
1833 op.getDest());
1834 return mlir::LogicalResult::success();
1835}
1836
1837mlir::LogicalResult CIRToLLVMGetMemberOpLowering::matchAndRewrite(
1838 cir::GetMemberOp op, OpAdaptor adaptor,
1839 mlir::ConversionPatternRewriter &rewriter) const {
1840 mlir::Type llResTy = getTypeConverter()->convertType(op.getType());
1841 const auto recordTy =
1842 mlir::cast<cir::RecordType>(op.getAddrTy().getPointee());
1843 assert(recordTy && "expected record type");
1844
1845 switch (recordTy.getKind()) {
1846 case cir::RecordType::Class:
1847 case cir::RecordType::Struct: {
1848 // Since the base address is a pointer to an aggregate, the first offset
1849 // is always zero. The second offset tell us which member it will access.
1850 llvm::SmallVector<mlir::LLVM::GEPArg, 2> offset{0, op.getIndex()};
1851 const mlir::Type elementTy = getTypeConverter()->convertType(recordTy);
1852 rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, llResTy, elementTy,
1853 adaptor.getAddr(), offset);
1854 return mlir::success();
1855 }
1856 case cir::RecordType::Union:
1857 // Union members share the address space, so we just need a bitcast to
1858 // conform to type-checking.
1859 rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(op, llResTy,
1860 adaptor.getAddr());
1861 return mlir::success();
1862 }
1863}
1864
1865mlir::LogicalResult CIRToLLVMTrapOpLowering::matchAndRewrite(
1866 cir::TrapOp op, OpAdaptor adaptor,
1867 mlir::ConversionPatternRewriter &rewriter) const {
1868 mlir::Location loc = op->getLoc();
1869 rewriter.eraseOp(op: op);
1870
1871 rewriter.create<mlir::LLVM::Trap>(loc);
1872
1873 // Note that the call to llvm.trap is not a terminator in LLVM dialect.
1874 // So we must emit an additional llvm.unreachable to terminate the current
1875 // block.
1876 rewriter.create<mlir::LLVM::UnreachableOp>(loc);
1877
1878 return mlir::success();
1879}
1880
1881mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite(
1882 cir::StackSaveOp op, OpAdaptor adaptor,
1883 mlir::ConversionPatternRewriter &rewriter) const {
1884 const mlir::Type ptrTy = getTypeConverter()->convertType(op.getType());
1885 rewriter.replaceOpWithNewOp<mlir::LLVM::StackSaveOp>(op, ptrTy);
1886 return mlir::success();
1887}
1888
1889mlir::LogicalResult CIRToLLVMStackRestoreOpLowering::matchAndRewrite(
1890 cir::StackRestoreOp op, OpAdaptor adaptor,
1891 mlir::ConversionPatternRewriter &rewriter) const {
1892 rewriter.replaceOpWithNewOp<mlir::LLVM::StackRestoreOp>(op, adaptor.getPtr());
1893 return mlir::success();
1894}
1895
1896mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite(
1897 cir::VecCreateOp op, OpAdaptor adaptor,
1898 mlir::ConversionPatternRewriter &rewriter) const {
1899 // Start with an 'undef' value for the vector. Then 'insertelement' for
1900 // each of the vector elements.
1901 const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
1902 const mlir::Type llvmTy = typeConverter->convertType(vecTy);
1903 const mlir::Location loc = op.getLoc();
1904 mlir::Value result = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
1905 assert(vecTy.getSize() == op.getElements().size() &&
1906 "cir.vec.create op count doesn't match vector type elements count");
1907
1908 for (uint64_t i = 0; i < vecTy.getSize(); ++i) {
1909 const mlir::Value indexValue =
1910 rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
1911 result = rewriter.create<mlir::LLVM::InsertElementOp>(
1912 loc, result, adaptor.getElements()[i], indexValue);
1913 }
1914
1915 rewriter.replaceOp(op, result);
1916 return mlir::success();
1917}
1918
1919mlir::LogicalResult CIRToLLVMVecExtractOpLowering::matchAndRewrite(
1920 cir::VecExtractOp op, OpAdaptor adaptor,
1921 mlir::ConversionPatternRewriter &rewriter) const {
1922 rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractElementOp>(
1923 op, adaptor.getVec(), adaptor.getIndex());
1924 return mlir::success();
1925}
1926
1927mlir::LogicalResult CIRToLLVMVecInsertOpLowering::matchAndRewrite(
1928 cir::VecInsertOp op, OpAdaptor adaptor,
1929 mlir::ConversionPatternRewriter &rewriter) const {
1930 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertElementOp>(
1931 op, adaptor.getVec(), adaptor.getValue(), adaptor.getIndex());
1932 return mlir::success();
1933}
1934
1935mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite(
1936 cir::VecCmpOp op, OpAdaptor adaptor,
1937 mlir::ConversionPatternRewriter &rewriter) const {
1938 mlir::Type elementType = elementTypeIfVector(op.getLhs().getType());
1939 mlir::Value bitResult;
1940 if (auto intType = mlir::dyn_cast<cir::IntType>(elementType)) {
1941 bitResult = rewriter.create<mlir::LLVM::ICmpOp>(
1942 op.getLoc(),
1943 convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned()),
1944 adaptor.getLhs(), adaptor.getRhs());
1945 } else if (mlir::isa<cir::CIRFPTypeInterface>(elementType)) {
1946 bitResult = rewriter.create<mlir::LLVM::FCmpOp>(
1947 op.getLoc(), convertCmpKindToFCmpPredicate(op.getKind()),
1948 adaptor.getLhs(), adaptor.getRhs());
1949 } else {
1950 return op.emitError() << "unsupported type for VecCmpOp: " << elementType;
1951 }
1952
1953 // LLVM IR vector comparison returns a vector of i1. This one-bit vector
1954 // must be sign-extended to the correct result type.
1955 rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(
1956 op, typeConverter->convertType(op.getType()), bitResult);
1957 return mlir::success();
1958}
1959
1960mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
1961 cir::VecSplatOp op, OpAdaptor adaptor,
1962 mlir::ConversionPatternRewriter &rewriter) const {
1963 // Vector splat can be implemented with an `insertelement` and a
1964 // `shufflevector`, which is better than an `insertelement` for each
1965 // element in the vector. Start with an undef vector. Insert the value into
1966 // the first element. Then use a `shufflevector` with a mask of all 0 to
1967 // fill out the entire vector with that value.
1968 cir::VectorType vecTy = op.getType();
1969 mlir::Type llvmTy = typeConverter->convertType(vecTy);
1970 mlir::Location loc = op.getLoc();
1971 mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
1972
1973 mlir::Value elementValue = adaptor.getValue();
1974 if (mlir::isa<mlir::LLVM::PoisonOp>(elementValue.getDefiningOp())) {
1975 // If the splat value is poison, then we can just use poison value
1976 // for the entire vector.
1977 rewriter.replaceOp(op, poison);
1978 return mlir::success();
1979 }
1980
1981 if (auto constValue =
1982 dyn_cast<mlir::LLVM::ConstantOp>(elementValue.getDefiningOp())) {
1983 if (auto intAttr = dyn_cast<mlir::IntegerAttr>(constValue.getValue())) {
1984 mlir::DenseIntElementsAttr denseVec = mlir::DenseIntElementsAttr::get(
1985 mlir::cast<mlir::ShapedType>(llvmTy), intAttr.getValue());
1986 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
1987 op, denseVec.getType(), denseVec);
1988 return mlir::success();
1989 }
1990
1991 if (auto fpAttr = dyn_cast<mlir::FloatAttr>(constValue.getValue())) {
1992 mlir::DenseFPElementsAttr denseVec = mlir::DenseFPElementsAttr::get(
1993 mlir::cast<mlir::ShapedType>(llvmTy), fpAttr.getValue());
1994 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
1995 op, denseVec.getType(), denseVec);
1996 return mlir::success();
1997 }
1998 }
1999
2000 mlir::Value indexValue =
2001 rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
2002 mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>(
2003 loc, poison, elementValue, indexValue);
2004 SmallVector<int32_t> zeroValues(vecTy.getSize(), 0);
2005 rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>(op, oneElement,
2006 poison, zeroValues);
2007 return mlir::success();
2008}
2009
2010mlir::LogicalResult CIRToLLVMVecShuffleOpLowering::matchAndRewrite(
2011 cir::VecShuffleOp op, OpAdaptor adaptor,
2012 mlir::ConversionPatternRewriter &rewriter) const {
2013 // LLVM::ShuffleVectorOp takes an ArrayRef of int for the list of indices.
2014 // Convert the ClangIR ArrayAttr of IntAttr constants into a
2015 // SmallVector<int>.
2016 SmallVector<int, 8> indices;
2017 std::transform(
2018 op.getIndices().begin(), op.getIndices().end(),
2019 std::back_inserter(x&: indices), [](mlir::Attribute intAttr) {
2020 return mlir::cast<cir::IntAttr>(intAttr).getValue().getSExtValue();
2021 });
2022 rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>(
2023 op, adaptor.getVec1(), adaptor.getVec2(), indices);
2024 return mlir::success();
2025}
2026
2027mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite(
2028 cir::VecShuffleDynamicOp op, OpAdaptor adaptor,
2029 mlir::ConversionPatternRewriter &rewriter) const {
2030 // LLVM IR does not have an operation that corresponds to this form of
2031 // the built-in.
2032 // __builtin_shufflevector(V, I)
2033 // is implemented as this pseudocode, where the for loop is unrolled
2034 // and N is the number of elements:
2035 //
2036 // result = undef
2037 // maskbits = NextPowerOf2(N - 1)
2038 // masked = I & maskbits
2039 // for (i in 0 <= i < N)
2040 // result[i] = V[masked[i]]
2041 mlir::Location loc = op.getLoc();
2042 mlir::Value input = adaptor.getVec();
2043 mlir::Type llvmIndexVecType =
2044 getTypeConverter()->convertType(op.getIndices().getType());
2045 mlir::Type llvmIndexType = getTypeConverter()->convertType(
2046 elementTypeIfVector(op.getIndices().getType()));
2047 uint64_t numElements =
2048 mlir::cast<cir::VectorType>(op.getVec().getType()).getSize();
2049
2050 uint64_t maskBits = llvm::NextPowerOf2(A: numElements - 1) - 1;
2051 mlir::Value maskValue = rewriter.create<mlir::LLVM::ConstantOp>(
2052 loc, llvmIndexType, rewriter.getIntegerAttr(llvmIndexType, maskBits));
2053 mlir::Value maskVector =
2054 rewriter.create<mlir::LLVM::UndefOp>(loc, llvmIndexVecType);
2055
2056 for (uint64_t i = 0; i < numElements; ++i) {
2057 mlir::Value idxValue =
2058 rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
2059 maskVector = rewriter.create<mlir::LLVM::InsertElementOp>(
2060 loc, maskVector, maskValue, idxValue);
2061 }
2062
2063 mlir::Value maskedIndices = rewriter.create<mlir::LLVM::AndOp>(
2064 loc, llvmIndexVecType, adaptor.getIndices(), maskVector);
2065 mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(
2066 loc, getTypeConverter()->convertType(op.getVec().getType()));
2067 for (uint64_t i = 0; i < numElements; ++i) {
2068 mlir::Value iValue =
2069 rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
2070 mlir::Value indexValue = rewriter.create<mlir::LLVM::ExtractElementOp>(
2071 loc, maskedIndices, iValue);
2072 mlir::Value valueAtIndex =
2073 rewriter.create<mlir::LLVM::ExtractElementOp>(loc, input, indexValue);
2074 result = rewriter.create<mlir::LLVM::InsertElementOp>(loc, result,
2075 valueAtIndex, iValue);
2076 }
2077 rewriter.replaceOp(op, result);
2078 return mlir::success();
2079}
2080
2081mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite(
2082 cir::VecTernaryOp op, OpAdaptor adaptor,
2083 mlir::ConversionPatternRewriter &rewriter) const {
2084 // Convert `cond` into a vector of i1, then use that in a `select` op.
2085 mlir::Value bitVec = rewriter.create<mlir::LLVM::ICmpOp>(
2086 op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(),
2087 rewriter.create<mlir::LLVM::ZeroOp>(
2088 op.getCond().getLoc(),
2089 typeConverter->convertType(op.getCond().getType())));
2090 rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
2091 op, bitVec, adaptor.getLhs(), adaptor.getRhs());
2092 return mlir::success();
2093}
2094
2095std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
2096 return std::make_unique<ConvertCIRToLLVMPass>();
2097}
2098
2099void populateCIRToLLVMPasses(mlir::OpPassManager &pm) {
2100 mlir::populateCIRPreLoweringPasses(pm);
2101 pm.addPass(pass: createConvertCIRToLLVMPass());
2102}
2103
2104std::unique_ptr<llvm::Module>
2105lowerDirectlyFromCIRToLLVMIR(mlir::ModuleOp mlirModule, LLVMContext &llvmCtx) {
2106 llvm::TimeTraceScope scope("lower from CIR to LLVM directly");
2107
2108 mlir::MLIRContext *mlirCtx = mlirModule.getContext();
2109
2110 mlir::PassManager pm(mlirCtx);
2111 populateCIRToLLVMPasses(pm);
2112
2113 (void)mlir::applyPassManagerCLOptions(pm);
2114
2115 if (mlir::failed(Result: pm.run(op: mlirModule))) {
2116 // FIXME: Handle any errors where they occurs and return a nullptr here.
2117 report_fatal_error(
2118 reason: "The pass manager failed to lower CIR to LLVMIR dialect!");
2119 }
2120
2121 mlir::registerBuiltinDialectTranslation(context&: *mlirCtx);
2122 mlir::registerLLVMDialectTranslation(context&: *mlirCtx);
2123 mlir::registerCIRDialectTranslation(*mlirCtx);
2124
2125 llvm::TimeTraceScope translateScope("translateModuleToLLVMIR");
2126
2127 StringRef moduleName = mlirModule.getName().value_or("CIRToLLVMModule");
2128 std::unique_ptr<llvm::Module> llvmModule =
2129 mlir::translateModuleToLLVMIR(module: mlirModule, llvmContext&: llvmCtx, name: moduleName);
2130
2131 if (!llvmModule) {
2132 // FIXME: Handle any errors where they occurs and return a nullptr here.
2133 report_fatal_error(reason: "Lowering from LLVMIR dialect to llvm IR failed!");
2134 }
2135
2136 return llvmModule;
2137}
2138} // namespace direct
2139} // namespace cir
2140

source code of clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp