//===- ImageOps.cpp - MLIR SPIR-V Image Ops ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the image operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15using namespace mlir;
16
17//===----------------------------------------------------------------------===//
18// Common utility functions
19//===----------------------------------------------------------------------===//
20
// TODO: In the future we should model image operands better, so we can move
// some verification into ODS.
23static LogicalResult verifyImageOperands(Operation *imageOp,
24 spirv::ImageOperandsAttr attr,
25 Operation::operand_range operands) {
26 if (!attr) {
27 if (operands.empty())
28 return success();
29
30 return imageOp->emitError(message: "the Image Operands should encode what operands "
31 "follow, as per Image Operands");
32 }
33
34 if (spirv::bitEnumContainsAll(attr.getValue(),
35 spirv::ImageOperands::Lod |
36 spirv::ImageOperands::Grad))
37 return imageOp->emitError(
38 message: "it is invalid to set both the Lod and Grad bits");
39
40 size_t index = 0;
41
42 // The order we process operands is important. In case of multiple argument
43 // taking operands, the arguments are ordered starting with operands having
44 // smaller-numbered bits first.
45 if (spirv::bitEnumContainsAny(attr.getValue(), spirv::ImageOperands::Bias)) {
46 if (!isa<spirv::ImplicitLodOpInterface>(imageOp))
47 return imageOp->emitError(
48 message: "Bias is only valid with implicit-lod instructions");
49
50 if (index + 1 > operands.size())
51 return imageOp->emitError(message: "Bias operand requires 1 argument");
52
53 if (!isa<FloatType>(Val: operands[index].getType()))
54 return imageOp->emitError(message: "Bias must be a floating-point type scalar");
55
56 auto samplingOp = cast<spirv::SamplingOpInterface>(imageOp);
57 auto sampledImageType =
58 cast<spirv::SampledImageType>(samplingOp.getSampledImage().getType());
59 auto imageType = cast<spirv::ImageType>(sampledImageType.getImageType());
60
61 if (!llvm::is_contained({spirv::Dim::Dim1D, spirv::Dim::Dim2D,
62 spirv::Dim::Dim3D, spirv::Dim::Cube},
63 imageType.getDim()))
64 return imageOp->emitError(
65 message: "Bias must only be used with an image type that has "
66 "a dim operand of 1D, 2D, 3D, or Cube");
67
68 if (imageType.getSamplingInfo() != spirv::ImageSamplingInfo::SingleSampled)
69 return imageOp->emitError(message: "Bias must only be used with an image type "
70 "that has a MS operand of 0");
71
72 ++index;
73 }
74
75 if (spirv::bitEnumContainsAny(attr.getValue(), spirv::ImageOperands::Lod)) {
76 if (!isa<spirv::ExplicitLodOpInterface>(imageOp) &&
77 !isa<spirv::FetchOpInterface>(imageOp))
78 return imageOp->emitError(
79 message: "Lod is only valid with explicit-lod and fetch instructions");
80
81 if (index + 1 > operands.size())
82 return imageOp->emitError(message: "Lod operand requires 1 argument");
83
84 spirv::ImageType imageType;
85
86 if (isa<spirv::SamplingOpInterface>(imageOp)) {
87 if (!isa<mlir::FloatType>(Val: operands[index].getType()))
88 return imageOp->emitError(message: "for sampling operations, Lod must be a "
89 "floating-point type scalar");
90
91 auto samplingOp = cast<spirv::SamplingOpInterface>(imageOp);
92 auto sampledImageType = llvm::cast<spirv::SampledImageType>(
93 samplingOp.getSampledImage().getType());
94 imageType = cast<spirv::ImageType>(sampledImageType.getImageType());
95 } else {
96 if (!isa<mlir::IntegerType>(Val: operands[index].getType()))
97 return imageOp->emitError(
98 message: "for fetch operations, Lod must be an integer type scalar");
99
100 auto fetchOp = cast<spirv::FetchOpInterface>(imageOp);
101 imageType = cast<spirv::ImageType>(fetchOp.getImage().getType());
102 }
103
104 if (!llvm::is_contained({spirv::Dim::Dim1D, spirv::Dim::Dim2D,
105 spirv::Dim::Dim3D, spirv::Dim::Cube},
106 imageType.getDim()))
107 return imageOp->emitError(
108 message: "Lod must only be used with an image type that has "
109 "a dim operand of 1D, 2D, 3D, or Cube");
110
111 if (imageType.getSamplingInfo() != spirv::ImageSamplingInfo::SingleSampled)
112 return imageOp->emitError(message: "Lod must only be used with an image type that "
113 "has a MS operand of 0");
114
115 ++index;
116 }
117
118 if (spirv::bitEnumContainsAny(attr.getValue(), spirv::ImageOperands::Grad)) {
119 if (!isa<spirv::ExplicitLodOpInterface>(imageOp))
120 return imageOp->emitError(
121 message: "Grad is only valid with explicit-lod instructions");
122
123 if (index + 2 > operands.size())
124 return imageOp->emitError(
125 message: "Grad operand requires 2 arguments (scalars or vectors)");
126
127 auto samplingOp = cast<spirv::SamplingOpInterface>(imageOp);
128 auto sampledImageType =
129 cast<spirv::SampledImageType>(samplingOp.getSampledImage().getType());
130 auto imageType = cast<spirv::ImageType>(sampledImageType.getImageType());
131
132 if (imageType.getSamplingInfo() != spirv::ImageSamplingInfo::SingleSampled)
133 return imageOp->emitError(message: "Grad must only be used with an image type "
134 "that has a MS operand of 0");
135
136 int64_t numberOfComponents = 0;
137
138 auto coordVector =
139 dyn_cast<mlir::VectorType>(samplingOp.getCoordinate().getType());
140 if (coordVector) {
141 numberOfComponents = coordVector.getNumElements();
142 if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
143 numberOfComponents -= 1;
144 } else {
145 numberOfComponents = 1;
146 }
147
148 assert(numberOfComponents > 0);
149
150 auto dXVector = dyn_cast<mlir::VectorType>(operands[index].getType());
151 auto dYVector = dyn_cast<mlir::VectorType>(operands[index + 1].getType());
152 if (dXVector && dYVector) {
153 if (dXVector.getNumElements() != dYVector.getNumElements() ||
154 dXVector.getNumElements() != numberOfComponents)
155 return imageOp->emitError(
156 message: "number of components of each Grad argument must equal the number "
157 "of components in coordinate, minus the array layer component, if "
158 "present");
159
160 if (!isa<mlir::FloatType>(dXVector.getElementType()) ||
161 !isa<mlir::FloatType>(dYVector.getElementType()))
162 return imageOp->emitError(
163 message: "Grad arguments must be a vector of floating-point type");
164 } else if (isa<mlir::FloatType>(Val: operands[index].getType()) &&
165 isa<mlir::FloatType>(Val: operands[index + 1].getType())) {
166 if (numberOfComponents != 1)
167 return imageOp->emitError(
168 message: "number of components of each Grad argument must equal the number "
169 "of components in coordinate, minus the array layer component, if "
170 "present");
171 } else {
172 return imageOp->emitError(
173 message: "Grad arguments must be a scalar or vector of floating-point type");
174 }
175
176 index += 2;
177 }
178
179 // TODO: Add the validation rules for the following Image Operands.
180 spirv::ImageOperands noSupportOperands =
181 spirv::ImageOperands::ConstOffset | spirv::ImageOperands::Offset |
182 spirv::ImageOperands::ConstOffsets | spirv::ImageOperands::Sample |
183 spirv::ImageOperands::MinLod | spirv::ImageOperands::MakeTexelAvailable |
184 spirv::ImageOperands::MakeTexelVisible |
185 spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
186
187 assert(!spirv::bitEnumContainsAny(attr.getValue(), noSupportOperands) &&
188 "unimplemented operands of Image Operands");
189 (void)noSupportOperands;
190
191 if (index < operands.size())
192 return imageOp->emitError(
193 message: "too many image operand arguments have been provided");
194
195 return success();
196}
197
198//===----------------------------------------------------------------------===//
199// spirv.ImageDrefGather
200//===----------------------------------------------------------------------===//
201
202LogicalResult spirv::ImageDrefGatherOp::verify() {
203 return verifyImageOperands(getOperation(), getImageOperandsAttr(),
204 getOperandArguments());
205}
206
207//===----------------------------------------------------------------------===//
208// spirv.ImageWriteOp
209//===----------------------------------------------------------------------===//
210
211LogicalResult spirv::ImageWriteOp::verify() {
212 // TODO: Do we need check for: "If the Arrayed operand is 1, then additional
213 // capabilities may be required; e.g., ImageCubeArray, or ImageMSArray."?
214
215 // TODO: Ideally it should be somewhere verified that "The Image Format must
216 // not be Unknown, unless the StorageImageWriteWithoutFormat Capability was
217 // declared." This function however may not be the suitable place for such
218 // verification.
219
220 return verifyImageOperands(getOperation(), getImageOperandsAttr(),
221 getOperandArguments());
222}
223
224//===----------------------------------------------------------------------===//
225// spirv.ImageQuerySize
226//===----------------------------------------------------------------------===//
227
228LogicalResult spirv::ImageQuerySizeOp::verify() {
229 spirv::ImageType imageType =
230 llvm::cast<spirv::ImageType>(getImage().getType());
231 Type resultType = getResult().getType();
232
233 spirv::Dim dim = imageType.getDim();
234 spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
235 spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
236 switch (dim) {
237 case spirv::Dim::Dim1D:
238 case spirv::Dim::Dim2D:
239 case spirv::Dim::Dim3D:
240 case spirv::Dim::Cube:
241 if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
242 samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
243 samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
244 return emitError(
245 "if Dim is 1D, 2D, 3D, or Cube, "
246 "it must also have either an MS of 1 or a Sampled of 0 or 2");
247 break;
248 case spirv::Dim::Buffer:
249 case spirv::Dim::Rect:
250 break;
251 default:
252 return emitError("the Dim operand of the image type must "
253 "be 1D, 2D, 3D, Buffer, Cube, or Rect");
254 }
255
256 unsigned componentNumber = 0;
257 switch (dim) {
258 case spirv::Dim::Dim1D:
259 case spirv::Dim::Buffer:
260 componentNumber = 1;
261 break;
262 case spirv::Dim::Dim2D:
263 case spirv::Dim::Cube:
264 case spirv::Dim::Rect:
265 componentNumber = 2;
266 break;
267 case spirv::Dim::Dim3D:
268 componentNumber = 3;
269 break;
270 default:
271 break;
272 }
273
274 if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
275 componentNumber += 1;
276
277 unsigned resultComponentNumber = 1;
278 if (auto resultVectorType = llvm::dyn_cast<VectorType>(resultType))
279 resultComponentNumber = resultVectorType.getNumElements();
280
281 if (componentNumber != resultComponentNumber)
282 return emitError("expected the result to have ")
283 << componentNumber << " component(s), but found "
284 << resultComponentNumber << " component(s)";
285
286 return success();
287}
288
289//===----------------------------------------------------------------------===//
290// spirv.ImageSampleImplicitLod
291//===----------------------------------------------------------------------===//
292
293LogicalResult spirv::ImageSampleImplicitLodOp::verify() {
294 return verifyImageOperands(getOperation(), getImageOperandsAttr(),
295 getOperandArguments());
296}
297
298//===----------------------------------------------------------------------===//
299// spirv.ImageSampleExplicitLod
300//===----------------------------------------------------------------------===//
301
302LogicalResult spirv::ImageSampleExplicitLodOp::verify() {
303 // TODO: It should be verified somewhere that: "Unless the Kernel capability
304 // is declared, it [Coordinate] must be floating point."
305
306 return verifyImageOperands(getOperation(), getImageOperandsAttr(),
307 getOperandArguments());
308}
309
310//===----------------------------------------------------------------------===//
311// spirv.ImageSampleProjDrefImplicitLod
312//===----------------------------------------------------------------------===//
313
314LogicalResult spirv::ImageSampleProjDrefImplicitLodOp::verify() {
315 return verifyImageOperands(getOperation(), getImageOperandsAttr(),
316 getOperandArguments());
317}
318

source code of mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp