1//===- PtrDialect.cpp - Pointer dialect ---------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Pointer dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Ptr/IR/PtrOps.h"
14#include "mlir/IR/DialectImplementation.h"
15#include "mlir/IR/Matchers.h"
16#include "mlir/Interfaces/DataLayoutInterfaces.h"
17#include "mlir/Transforms/InliningUtils.h"
18#include "llvm/ADT/TypeSwitch.h"
19
20using namespace mlir;
21using namespace mlir::ptr;
22
23//===----------------------------------------------------------------------===//
24// Pointer dialect
25//===----------------------------------------------------------------------===//
26
/// Registers the Pointer dialect's operations, attributes, and types with this
/// dialect instance. Each list comes from a TableGen-generated `.cpp.inc`
/// file. The `GET_*_LIST` macro defined just before each include selects the
/// section of that file that expands inside the template argument list.
void PtrDialect::initialize() {
  // Operations generated from the dialect's op definitions.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
      >();
  // Attributes generated from the dialect's attribute definitions.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
      >();
  // Types generated from the dialect's type definitions.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
      >();
}
41
42//===----------------------------------------------------------------------===//
43// FromPtrOp
44//===----------------------------------------------------------------------===//
45
46OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
47 // Fold the pattern:
48 // %ptr = ptr.to_ptr %v : type -> ptr
49 // (%mda = ptr.get_metadata %v : type)?
50 // %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
51 // To:
52 // %val -> %v
53 Value ptrLike;
54 FromPtrOp fromPtr = *this;
55 while (fromPtr != nullptr) {
56 auto toPtr = dyn_cast_or_null<ToPtrOp>(Val: fromPtr.getPtr().getDefiningOp());
57 // Cannot fold if it's not a `to_ptr` op or the initial and final types are
58 // different.
59 if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
60 return ptrLike;
61 Value md = fromPtr.getMetadata();
62 // If the type has trivial metadata fold.
63 if (!fromPtr.getType().hasPtrMetadata()) {
64 ptrLike = toPtr.getPtr();
65 } else if (md) {
66 // Fold if the metadata can be verified to be equal.
67 if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(Val: md.getDefiningOp());
68 mdOp && mdOp.getPtr() == toPtr.getPtr())
69 ptrLike = toPtr.getPtr();
70 }
71 // Check for a sequence of casts.
72 fromPtr = dyn_cast_or_null<FromPtrOp>(Val: ptrLike ? ptrLike.getDefiningOp()
73 : nullptr);
74 }
75 return ptrLike;
76}
77
78LogicalResult FromPtrOp::verify() {
79 if (isa<PtrType>(Val: getType()))
80 return emitError() << "the result type cannot be `!ptr.ptr`";
81 if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
82 return emitError()
83 << "expected the input and output to have the same memory space";
84 }
85 return success();
86}
87
88//===----------------------------------------------------------------------===//
89// PtrAddOp
90//===----------------------------------------------------------------------===//
91
92/// Fold: ptradd ptr + 0 -> ptr
93OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
94 Attribute attr = adaptor.getOffset();
95 if (!attr)
96 return nullptr;
97 if (llvm::APInt value; m_ConstantInt(bind_value: &value).match(attr) && value.isZero())
98 return getBase();
99 return nullptr;
100}
101
102//===----------------------------------------------------------------------===//
103// ToPtrOp
104//===----------------------------------------------------------------------===//
105
106OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
107 // Fold the pattern:
108 // %val = ptr.from_ptr %p (metadata ...)? : ptr -> type
109 // %ptr = ptr.to_ptr %val : type -> ptr
110 // To:
111 // %ptr -> %p
112 Value ptr;
113 ToPtrOp toPtr = *this;
114 while (toPtr != nullptr) {
115 auto fromPtr = dyn_cast_or_null<FromPtrOp>(Val: toPtr.getPtr().getDefiningOp());
116 // Cannot fold if it's not a `from_ptr` op.
117 if (!fromPtr)
118 return ptr;
119 ptr = fromPtr.getPtr();
120 // Check for chains of casts.
121 toPtr = dyn_cast_or_null<ToPtrOp>(Val: ptr.getDefiningOp());
122 }
123 return ptr;
124}
125
126LogicalResult ToPtrOp::verify() {
127 if (isa<PtrType>(Val: getPtr().getType()))
128 return emitError() << "the input value cannot be of type `!ptr.ptr`";
129 if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
130 return emitError()
131 << "expected the input and output to have the same memory space";
132 }
133 return success();
134}
135
136//===----------------------------------------------------------------------===//
137// TypeOffsetOp
138//===----------------------------------------------------------------------===//
139
140llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
141 if (layout)
142 return layout->getTypeSize(t: getElementType());
143 DataLayout dl = DataLayout::closest(op: *this);
144 return dl.getTypeSize(t: getElementType());
145}
146
147//===----------------------------------------------------------------------===//
148// Pointer API.
149//===----------------------------------------------------------------------===//
150
151#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"
152
153#define GET_ATTRDEF_CLASSES
154#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
155
156#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp.inc"
157
158#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc"
159
160#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc"
161
162#define GET_TYPEDEF_CLASSES
163#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
164
165#define GET_OP_CLASSES
166#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
167

// Source: mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp