1 | //===- ImageOps.cpp - MLIR SPIR-V Image Ops ------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Defines the image operations in the SPIR-V dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
14 | |
15 | using namespace mlir; |
16 | |
17 | //===----------------------------------------------------------------------===// |
18 | // Common utility functions |
19 | //===----------------------------------------------------------------------===// |
20 | |
21 | // TODO: In the future we should model image operands better, so we can move |
22 | // some verification into ODS. |
23 | static LogicalResult verifyImageOperands(Operation *imageOp, |
24 | spirv::ImageOperandsAttr attr, |
25 | Operation::operand_range operands) { |
26 | if (!attr) { |
27 | if (operands.empty()) |
28 | return success(); |
29 | |
30 | return imageOp->emitError(message: "the Image Operands should encode what operands " |
31 | "follow, as per Image Operands" ); |
32 | } |
33 | |
34 | if (spirv::bitEnumContainsAll(attr.getValue(), |
35 | spirv::ImageOperands::Lod | |
36 | spirv::ImageOperands::Grad)) |
37 | return imageOp->emitError( |
38 | message: "it is invalid to set both the Lod and Grad bits" ); |
39 | |
40 | size_t index = 0; |
41 | |
42 | // The order we process operands is important. In case of multiple argument |
43 | // taking operands, the arguments are ordered starting with operands having |
44 | // smaller-numbered bits first. |
45 | if (spirv::bitEnumContainsAny(attr.getValue(), spirv::ImageOperands::Bias)) { |
46 | if (!isa<spirv::ImplicitLodOpInterface>(imageOp)) |
47 | return imageOp->emitError( |
48 | message: "Bias is only valid with implicit-lod instructions" ); |
49 | |
50 | if (index + 1 > operands.size()) |
51 | return imageOp->emitError(message: "Bias operand requires 1 argument" ); |
52 | |
53 | if (!isa<FloatType>(Val: operands[index].getType())) |
54 | return imageOp->emitError(message: "Bias must be a floating-point type scalar" ); |
55 | |
56 | auto samplingOp = cast<spirv::SamplingOpInterface>(imageOp); |
57 | auto sampledImageType = |
58 | cast<spirv::SampledImageType>(samplingOp.getSampledImage().getType()); |
59 | auto imageType = cast<spirv::ImageType>(sampledImageType.getImageType()); |
60 | |
61 | if (!llvm::is_contained({spirv::Dim::Dim1D, spirv::Dim::Dim2D, |
62 | spirv::Dim::Dim3D, spirv::Dim::Cube}, |
63 | imageType.getDim())) |
64 | return imageOp->emitError( |
65 | message: "Bias must only be used with an image type that has " |
66 | "a dim operand of 1D, 2D, 3D, or Cube" ); |
67 | |
68 | if (imageType.getSamplingInfo() != spirv::ImageSamplingInfo::SingleSampled) |
69 | return imageOp->emitError(message: "Bias must only be used with an image type " |
70 | "that has a MS operand of 0" ); |
71 | |
72 | ++index; |
73 | } |
74 | |
75 | if (spirv::bitEnumContainsAny(attr.getValue(), spirv::ImageOperands::Lod)) { |
76 | if (!isa<spirv::ExplicitLodOpInterface>(imageOp) && |
77 | !isa<spirv::FetchOpInterface>(imageOp)) |
78 | return imageOp->emitError( |
79 | message: "Lod is only valid with explicit-lod and fetch instructions" ); |
80 | |
81 | if (index + 1 > operands.size()) |
82 | return imageOp->emitError(message: "Lod operand requires 1 argument" ); |
83 | |
84 | spirv::ImageType imageType; |
85 | |
86 | if (isa<spirv::SamplingOpInterface>(imageOp)) { |
87 | if (!isa<mlir::FloatType>(Val: operands[index].getType())) |
88 | return imageOp->emitError(message: "for sampling operations, Lod must be a " |
89 | "floating-point type scalar" ); |
90 | |
91 | auto samplingOp = cast<spirv::SamplingOpInterface>(imageOp); |
92 | auto sampledImageType = llvm::cast<spirv::SampledImageType>( |
93 | samplingOp.getSampledImage().getType()); |
94 | imageType = cast<spirv::ImageType>(sampledImageType.getImageType()); |
95 | } else { |
96 | if (!isa<mlir::IntegerType>(Val: operands[index].getType())) |
97 | return imageOp->emitError( |
98 | message: "for fetch operations, Lod must be an integer type scalar" ); |
99 | |
100 | auto fetchOp = cast<spirv::FetchOpInterface>(imageOp); |
101 | imageType = cast<spirv::ImageType>(fetchOp.getImage().getType()); |
102 | } |
103 | |
104 | if (!llvm::is_contained({spirv::Dim::Dim1D, spirv::Dim::Dim2D, |
105 | spirv::Dim::Dim3D, spirv::Dim::Cube}, |
106 | imageType.getDim())) |
107 | return imageOp->emitError( |
108 | message: "Lod must only be used with an image type that has " |
109 | "a dim operand of 1D, 2D, 3D, or Cube" ); |
110 | |
111 | if (imageType.getSamplingInfo() != spirv::ImageSamplingInfo::SingleSampled) |
112 | return imageOp->emitError(message: "Lod must only be used with an image type that " |
113 | "has a MS operand of 0" ); |
114 | |
115 | ++index; |
116 | } |
117 | |
118 | if (spirv::bitEnumContainsAny(attr.getValue(), spirv::ImageOperands::Grad)) { |
119 | if (!isa<spirv::ExplicitLodOpInterface>(imageOp)) |
120 | return imageOp->emitError( |
121 | message: "Grad is only valid with explicit-lod instructions" ); |
122 | |
123 | if (index + 2 > operands.size()) |
124 | return imageOp->emitError( |
125 | message: "Grad operand requires 2 arguments (scalars or vectors)" ); |
126 | |
127 | auto samplingOp = cast<spirv::SamplingOpInterface>(imageOp); |
128 | auto sampledImageType = |
129 | cast<spirv::SampledImageType>(samplingOp.getSampledImage().getType()); |
130 | auto imageType = cast<spirv::ImageType>(sampledImageType.getImageType()); |
131 | |
132 | if (imageType.getSamplingInfo() != spirv::ImageSamplingInfo::SingleSampled) |
133 | return imageOp->emitError(message: "Grad must only be used with an image type " |
134 | "that has a MS operand of 0" ); |
135 | |
136 | int64_t numberOfComponents = 0; |
137 | |
138 | auto coordVector = |
139 | dyn_cast<mlir::VectorType>(samplingOp.getCoordinate().getType()); |
140 | if (coordVector) { |
141 | numberOfComponents = coordVector.getNumElements(); |
142 | if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed) |
143 | numberOfComponents -= 1; |
144 | } else { |
145 | numberOfComponents = 1; |
146 | } |
147 | |
148 | assert(numberOfComponents > 0); |
149 | |
150 | auto dXVector = dyn_cast<mlir::VectorType>(operands[index].getType()); |
151 | auto dYVector = dyn_cast<mlir::VectorType>(operands[index + 1].getType()); |
152 | if (dXVector && dYVector) { |
153 | if (dXVector.getNumElements() != dYVector.getNumElements() || |
154 | dXVector.getNumElements() != numberOfComponents) |
155 | return imageOp->emitError( |
156 | message: "number of components of each Grad argument must equal the number " |
157 | "of components in coordinate, minus the array layer component, if " |
158 | "present" ); |
159 | |
160 | if (!isa<mlir::FloatType>(dXVector.getElementType()) || |
161 | !isa<mlir::FloatType>(dYVector.getElementType())) |
162 | return imageOp->emitError( |
163 | message: "Grad arguments must be a vector of floating-point type" ); |
164 | } else if (isa<mlir::FloatType>(Val: operands[index].getType()) && |
165 | isa<mlir::FloatType>(Val: operands[index + 1].getType())) { |
166 | if (numberOfComponents != 1) |
167 | return imageOp->emitError( |
168 | message: "number of components of each Grad argument must equal the number " |
169 | "of components in coordinate, minus the array layer component, if " |
170 | "present" ); |
171 | } else { |
172 | return imageOp->emitError( |
173 | message: "Grad arguments must be a scalar or vector of floating-point type" ); |
174 | } |
175 | |
176 | index += 2; |
177 | } |
178 | |
179 | // TODO: Add the validation rules for the following Image Operands. |
180 | spirv::ImageOperands noSupportOperands = |
181 | spirv::ImageOperands::ConstOffset | spirv::ImageOperands::Offset | |
182 | spirv::ImageOperands::ConstOffsets | spirv::ImageOperands::Sample | |
183 | spirv::ImageOperands::MinLod | spirv::ImageOperands::MakeTexelAvailable | |
184 | spirv::ImageOperands::MakeTexelVisible | |
185 | spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend; |
186 | |
187 | assert(!spirv::bitEnumContainsAny(attr.getValue(), noSupportOperands) && |
188 | "unimplemented operands of Image Operands" ); |
189 | (void)noSupportOperands; |
190 | |
191 | if (index < operands.size()) |
192 | return imageOp->emitError( |
193 | message: "too many image operand arguments have been provided" ); |
194 | |
195 | return success(); |
196 | } |
197 | |
198 | //===----------------------------------------------------------------------===// |
199 | // spirv.ImageDrefGather |
200 | //===----------------------------------------------------------------------===// |
201 | |
202 | LogicalResult spirv::ImageDrefGatherOp::verify() { |
203 | return verifyImageOperands(getOperation(), getImageOperandsAttr(), |
204 | getOperandArguments()); |
205 | } |
206 | |
207 | //===----------------------------------------------------------------------===// |
208 | // spirv.ImageWriteOp |
209 | //===----------------------------------------------------------------------===// |
210 | |
211 | LogicalResult spirv::ImageWriteOp::verify() { |
212 | // TODO: Do we need check for: "If the Arrayed operand is 1, then additional |
213 | // capabilities may be required; e.g., ImageCubeArray, or ImageMSArray."? |
214 | |
215 | // TODO: Ideally it should be somewhere verified that "The Image Format must |
216 | // not be Unknown, unless the StorageImageWriteWithoutFormat Capability was |
217 | // declared." This function however may not be the suitable place for such |
218 | // verification. |
219 | |
220 | return verifyImageOperands(getOperation(), getImageOperandsAttr(), |
221 | getOperandArguments()); |
222 | } |
223 | |
224 | //===----------------------------------------------------------------------===// |
225 | // spirv.ImageQuerySize |
226 | //===----------------------------------------------------------------------===// |
227 | |
228 | LogicalResult spirv::ImageQuerySizeOp::verify() { |
229 | spirv::ImageType imageType = |
230 | llvm::cast<spirv::ImageType>(getImage().getType()); |
231 | Type resultType = getResult().getType(); |
232 | |
233 | spirv::Dim dim = imageType.getDim(); |
234 | spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo(); |
235 | spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo(); |
236 | switch (dim) { |
237 | case spirv::Dim::Dim1D: |
238 | case spirv::Dim::Dim2D: |
239 | case spirv::Dim::Dim3D: |
240 | case spirv::Dim::Cube: |
241 | if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled && |
242 | samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown && |
243 | samplerInfo != spirv::ImageSamplerUseInfo::NoSampler) |
244 | return emitError( |
245 | "if Dim is 1D, 2D, 3D, or Cube, " |
246 | "it must also have either an MS of 1 or a Sampled of 0 or 2" ); |
247 | break; |
248 | case spirv::Dim::Buffer: |
249 | case spirv::Dim::Rect: |
250 | break; |
251 | default: |
252 | return emitError("the Dim operand of the image type must " |
253 | "be 1D, 2D, 3D, Buffer, Cube, or Rect" ); |
254 | } |
255 | |
256 | unsigned componentNumber = 0; |
257 | switch (dim) { |
258 | case spirv::Dim::Dim1D: |
259 | case spirv::Dim::Buffer: |
260 | componentNumber = 1; |
261 | break; |
262 | case spirv::Dim::Dim2D: |
263 | case spirv::Dim::Cube: |
264 | case spirv::Dim::Rect: |
265 | componentNumber = 2; |
266 | break; |
267 | case spirv::Dim::Dim3D: |
268 | componentNumber = 3; |
269 | break; |
270 | default: |
271 | break; |
272 | } |
273 | |
274 | if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed) |
275 | componentNumber += 1; |
276 | |
277 | unsigned resultComponentNumber = 1; |
278 | if (auto resultVectorType = llvm::dyn_cast<VectorType>(resultType)) |
279 | resultComponentNumber = resultVectorType.getNumElements(); |
280 | |
281 | if (componentNumber != resultComponentNumber) |
282 | return emitError("expected the result to have " ) |
283 | << componentNumber << " component(s), but found " |
284 | << resultComponentNumber << " component(s)" ; |
285 | |
286 | return success(); |
287 | } |
288 | |
289 | //===----------------------------------------------------------------------===// |
290 | // spirv.ImageSampleImplicitLod |
291 | //===----------------------------------------------------------------------===// |
292 | |
293 | LogicalResult spirv::ImageSampleImplicitLodOp::verify() { |
294 | return verifyImageOperands(getOperation(), getImageOperandsAttr(), |
295 | getOperandArguments()); |
296 | } |
297 | |
298 | //===----------------------------------------------------------------------===// |
299 | // spirv.ImageSampleExplicitLod |
300 | //===----------------------------------------------------------------------===// |
301 | |
302 | LogicalResult spirv::ImageSampleExplicitLodOp::verify() { |
303 | // TODO: It should be verified somewhere that: "Unless the Kernel capability |
304 | // is declared, it [Coordinate] must be floating point." |
305 | |
306 | return verifyImageOperands(getOperation(), getImageOperandsAttr(), |
307 | getOperandArguments()); |
308 | } |
309 | |
310 | //===----------------------------------------------------------------------===// |
311 | // spirv.ImageSampleProjDrefImplicitLod |
312 | //===----------------------------------------------------------------------===// |
313 | |
314 | LogicalResult spirv::ImageSampleProjDrefImplicitLodOp::verify() { |
315 | return verifyImageOperands(getOperation(), getImageOperandsAttr(), |
316 | getOperandArguments()); |
317 | } |
318 | |