1//===- ImageOps.cpp - MLIR SPIR-V Image Ops ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines the image operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15using namespace mlir;
16
17//===----------------------------------------------------------------------===//
18// Common utility functions
19//===----------------------------------------------------------------------===//
20
21// TODO: In the future we should model image operands better, so we can move
22// some verification into ODS.
23static LogicalResult verifyImageOperands(Operation *imageOp,
24 spirv::ImageOperandsAttr attr,
25 Operation::operand_range operands) {
26 if (!attr) {
27 if (operands.empty())
28 return success();
29
30 return imageOp->emitError(message: "the Image Operands should encode what operands "
31 "follow, as per Image Operands");
32 }
33
34 if (spirv::bitEnumContainsAll(bits: attr.getValue(),
35 bit: spirv::ImageOperands::Lod |
36 spirv::ImageOperands::Grad))
37 return imageOp->emitError(
38 message: "it is invalid to set both the Lod and Grad bits");
39
40 size_t index = 0;
41
42 // The order we process operands is important. In case of multiple argument
43 // taking operands, the arguments are ordered starting with operands having
44 // smaller-numbered bits first.
45 if (spirv::bitEnumContainsAny(bits: attr.getValue(), bit: spirv::ImageOperands::Bias)) {
46 if (!isa<spirv::ImplicitLodOpInterface>(Val: imageOp))
47 return imageOp->emitError(
48 message: "Bias is only valid with implicit-lod instructions");
49
50 if (index + 1 > operands.size())
51 return imageOp->emitError(message: "Bias operand requires 1 argument");
52
53 if (!isa<FloatType>(Val: operands[index].getType()))
54 return imageOp->emitError(message: "Bias must be a floating-point type scalar");
55
56 auto samplingOp = cast<spirv::SamplingOpInterface>(Val: imageOp);
57 auto sampledImageType =
58 cast<spirv::SampledImageType>(Val: samplingOp.getSampledImage().getType());
59 auto imageType = cast<spirv::ImageType>(Val: sampledImageType.getImageType());
60
61 if (!llvm::is_contained(Set: {spirv::Dim::Dim1D, spirv::Dim::Dim2D,
62 spirv::Dim::Dim3D, spirv::Dim::Cube},
63 Element: imageType.getDim()))
64 return imageOp->emitError(
65 message: "Bias must only be used with an image type that has "
66 "a dim operand of 1D, 2D, 3D, or Cube");
67
68 if (imageType.getSamplingInfo() != spirv::ImageSamplingInfo::SingleSampled)
69 return imageOp->emitError(message: "Bias must only be used with an image type "
70 "that has a MS operand of 0");
71
72 ++index;
73 }
74
75 if (spirv::bitEnumContainsAny(bits: attr.getValue(), bit: spirv::ImageOperands::Lod)) {
76 if (!isa<spirv::ExplicitLodOpInterface>(Val: imageOp) &&
77 !isa<spirv::FetchOpInterface>(Val: imageOp))
78 return imageOp->emitError(
79 message: "Lod is only valid with explicit-lod and fetch instructions");
80
81 if (index + 1 > operands.size())
82 return imageOp->emitError(message: "Lod operand requires 1 argument");
83
84 spirv::ImageType imageType;
85
86 if (isa<spirv::SamplingOpInterface>(Val: imageOp)) {
87 if (!isa<mlir::FloatType>(Val: operands[index].getType()))
88 return imageOp->emitError(message: "for sampling operations, Lod must be a "
89 "floating-point type scalar");
90
91 auto samplingOp = cast<spirv::SamplingOpInterface>(Val: imageOp);
92 auto sampledImageType = llvm::cast<spirv::SampledImageType>(
93 Val: samplingOp.getSampledImage().getType());
94 imageType = cast<spirv::ImageType>(Val: sampledImageType.getImageType());
95 } else {
96 if (!isa<mlir::IntegerType>(Val: operands[index].getType()))
97 return imageOp->emitError(
98 message: "for fetch operations, Lod must be an integer type scalar");
99
100 auto fetchOp = cast<spirv::FetchOpInterface>(Val: imageOp);
101 imageType = cast<spirv::ImageType>(Val: fetchOp.getImage().getType());
102 }
103
104 if (!llvm::is_contained(Set: {spirv::Dim::Dim1D, spirv::Dim::Dim2D,
105 spirv::Dim::Dim3D, spirv::Dim::Cube},
106 Element: imageType.getDim()))
107 return imageOp->emitError(
108 message: "Lod must only be used with an image type that has "
109 "a dim operand of 1D, 2D, 3D, or Cube");
110
111 if (imageType.getSamplingInfo() != spirv::ImageSamplingInfo::SingleSampled)
112 return imageOp->emitError(message: "Lod must only be used with an image type that "
113 "has a MS operand of 0");
114
115 ++index;
116 }
117
118 if (spirv::bitEnumContainsAny(bits: attr.getValue(), bit: spirv::ImageOperands::Grad)) {
119 if (!isa<spirv::ExplicitLodOpInterface>(Val: imageOp))
120 return imageOp->emitError(
121 message: "Grad is only valid with explicit-lod instructions");
122
123 if (index + 2 > operands.size())
124 return imageOp->emitError(
125 message: "Grad operand requires 2 arguments (scalars or vectors)");
126
127 auto samplingOp = cast<spirv::SamplingOpInterface>(Val: imageOp);
128 auto sampledImageType =
129 cast<spirv::SampledImageType>(Val: samplingOp.getSampledImage().getType());
130 auto imageType = cast<spirv::ImageType>(Val: sampledImageType.getImageType());
131
132 if (imageType.getSamplingInfo() != spirv::ImageSamplingInfo::SingleSampled)
133 return imageOp->emitError(message: "Grad must only be used with an image type "
134 "that has a MS operand of 0");
135
136 int64_t numberOfComponents = 0;
137
138 auto coordVector =
139 dyn_cast<mlir::VectorType>(Val: samplingOp.getCoordinate().getType());
140 if (coordVector) {
141 numberOfComponents = coordVector.getNumElements();
142 if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
143 numberOfComponents -= 1;
144 } else {
145 numberOfComponents = 1;
146 }
147
148 assert(numberOfComponents > 0);
149
150 auto dXVector = dyn_cast<mlir::VectorType>(Val: operands[index].getType());
151 auto dYVector = dyn_cast<mlir::VectorType>(Val: operands[index + 1].getType());
152 if (dXVector && dYVector) {
153 if (dXVector.getNumElements() != dYVector.getNumElements() ||
154 dXVector.getNumElements() != numberOfComponents)
155 return imageOp->emitError(
156 message: "number of components of each Grad argument must equal the number "
157 "of components in coordinate, minus the array layer component, if "
158 "present");
159
160 if (!isa<mlir::FloatType>(Val: dXVector.getElementType()) ||
161 !isa<mlir::FloatType>(Val: dYVector.getElementType()))
162 return imageOp->emitError(
163 message: "Grad arguments must be a vector of floating-point type");
164 } else if (isa<mlir::FloatType>(Val: operands[index].getType()) &&
165 isa<mlir::FloatType>(Val: operands[index + 1].getType())) {
166 if (numberOfComponents != 1)
167 return imageOp->emitError(
168 message: "number of components of each Grad argument must equal the number "
169 "of components in coordinate, minus the array layer component, if "
170 "present");
171 } else {
172 return imageOp->emitError(
173 message: "Grad arguments must be a scalar or vector of floating-point type");
174 }
175
176 index += 2;
177 }
178
179 // TODO: Add the validation rules for the following Image Operands.
180 spirv::ImageOperands noSupportOperands =
181 spirv::ImageOperands::ConstOffset | spirv::ImageOperands::Offset |
182 spirv::ImageOperands::ConstOffsets | spirv::ImageOperands::Sample |
183 spirv::ImageOperands::MinLod | spirv::ImageOperands::MakeTexelAvailable |
184 spirv::ImageOperands::MakeTexelVisible |
185 spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
186
187 assert(!spirv::bitEnumContainsAny(attr.getValue(), noSupportOperands) &&
188 "unimplemented operands of Image Operands");
189 (void)noSupportOperands;
190
191 if (index < operands.size())
192 return imageOp->emitError(
193 message: "too many image operand arguments have been provided");
194
195 return success();
196}
197
198//===----------------------------------------------------------------------===//
199// spirv.ImageDrefGather
200//===----------------------------------------------------------------------===//
201
202LogicalResult spirv::ImageDrefGatherOp::verify() {
203 return verifyImageOperands(imageOp: getOperation(), attr: getImageOperandsAttr(),
204 operands: getOperandArguments());
205}
206
207//===----------------------------------------------------------------------===//
208// spirv.ImageReadOp
209//===----------------------------------------------------------------------===//
210
211LogicalResult spirv::ImageReadOp::verify() {
212 // TODO: Do we need check for: "If the Arrayed operand is 1, then additional
213 // capabilities may be required; e.g., ImageCubeArray, or ImageMSArray."?
214
215 // TODO: Ideally it should be somewhere verified that "If the Image Dim
216 // operand is not SubpassData, the Image Format must not be Unknown, unless
217 // the StorageImageReadWithoutFormat Capability was declared." This function
218 // however may not be the suitable place for such verification.
219
220 return verifyImageOperands(imageOp: getOperation(), attr: getImageOperandsAttr(),
221 operands: getOperandArguments());
222}
223
224//===----------------------------------------------------------------------===//
225// spirv.ImageWriteOp
226//===----------------------------------------------------------------------===//
227
228LogicalResult spirv::ImageWriteOp::verify() {
229 // TODO: Do we need check for: "If the Arrayed operand is 1, then additional
230 // capabilities may be required; e.g., ImageCubeArray, or ImageMSArray."?
231
232 // TODO: Ideally it should be somewhere verified that "The Image Format must
233 // not be Unknown, unless the StorageImageWriteWithoutFormat Capability was
234 // declared." This function however may not be the suitable place for such
235 // verification.
236
237 return verifyImageOperands(imageOp: getOperation(), attr: getImageOperandsAttr(),
238 operands: getOperandArguments());
239}
240
241//===----------------------------------------------------------------------===//
242// spirv.ImageQuerySize
243//===----------------------------------------------------------------------===//
244
245LogicalResult spirv::ImageQuerySizeOp::verify() {
246 spirv::ImageType imageType =
247 llvm::cast<spirv::ImageType>(Val: getImage().getType());
248 Type resultType = getResult().getType();
249
250 spirv::Dim dim = imageType.getDim();
251 spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
252 spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
253 switch (dim) {
254 case spirv::Dim::Dim1D:
255 case spirv::Dim::Dim2D:
256 case spirv::Dim::Dim3D:
257 case spirv::Dim::Cube:
258 if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
259 samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
260 samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
261 return emitError(
262 message: "if Dim is 1D, 2D, 3D, or Cube, "
263 "it must also have either an MS of 1 or a Sampled of 0 or 2");
264 break;
265 case spirv::Dim::Buffer:
266 case spirv::Dim::Rect:
267 break;
268 default:
269 return emitError(message: "the Dim operand of the image type must "
270 "be 1D, 2D, 3D, Buffer, Cube, or Rect");
271 }
272
273 unsigned componentNumber = 0;
274 switch (dim) {
275 case spirv::Dim::Dim1D:
276 case spirv::Dim::Buffer:
277 componentNumber = 1;
278 break;
279 case spirv::Dim::Dim2D:
280 case spirv::Dim::Cube:
281 case spirv::Dim::Rect:
282 componentNumber = 2;
283 break;
284 case spirv::Dim::Dim3D:
285 componentNumber = 3;
286 break;
287 default:
288 break;
289 }
290
291 if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
292 componentNumber += 1;
293
294 unsigned resultComponentNumber = 1;
295 if (auto resultVectorType = llvm::dyn_cast<VectorType>(Val&: resultType))
296 resultComponentNumber = resultVectorType.getNumElements();
297
298 if (componentNumber != resultComponentNumber)
299 return emitError(message: "expected the result to have ")
300 << componentNumber << " component(s), but found "
301 << resultComponentNumber << " component(s)";
302
303 return success();
304}
305
306//===----------------------------------------------------------------------===//
307// spirv.ImageSampleImplicitLod
308//===----------------------------------------------------------------------===//
309
310LogicalResult spirv::ImageSampleImplicitLodOp::verify() {
311 return verifyImageOperands(imageOp: getOperation(), attr: getImageOperandsAttr(),
312 operands: getOperandArguments());
313}
314
315//===----------------------------------------------------------------------===//
316// spirv.ImageSampleExplicitLod
317//===----------------------------------------------------------------------===//
318
319LogicalResult spirv::ImageSampleExplicitLodOp::verify() {
320 // TODO: It should be verified somewhere that: "Unless the Kernel capability
321 // is declared, it [Coordinate] must be floating point."
322
323 return verifyImageOperands(imageOp: getOperation(), attr: getImageOperandsAttr(),
324 operands: getOperandArguments());
325}
326
327//===----------------------------------------------------------------------===//
328// spirv.ImageSampleProjDrefImplicitLod
329//===----------------------------------------------------------------------===//
330
331LogicalResult spirv::ImageSampleProjDrefImplicitLodOp::verify() {
332 return verifyImageOperands(imageOp: getOperation(), attr: getImageOperandsAttr(),
333 operands: getOperandArguments());
334}
335
336//===----------------------------------------------------------------------===//
337// spirv.ImageFetchOp
338//===----------------------------------------------------------------------===//
339
340LogicalResult spirv::ImageFetchOp::verify() {
341 return verifyImageOperands(imageOp: getOperation(), attr: getImageOperandsAttr(),
342 operands: getOperandArguments());
343}
344

source code of mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp