1//===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert SCF dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
14#include "mlir/Dialect/SCF/IR/SCF.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
17#include "mlir/Transforms/DialectConversion.h"
18#include "llvm/Support/FormatVariadic.h"
19
20using namespace mlir;
21
22//===----------------------------------------------------------------------===//
23// Context
24//===----------------------------------------------------------------------===//
25
26namespace mlir {
27struct ScfToSPIRVContextImpl {
28 // Map between the spirv region control flow operation (spirv.mlir.loop or
29 // spirv.mlir.selection) to the VariableOp created to store the region
30 // results. The order of the VariableOp matches the order of the results.
31 DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
32};
33} // namespace mlir
34
35/// We use ScfToSPIRVContext to store information about the lowering of the scf
36/// region that need to be used later on. When we lower scf.for/scf.if we create
37/// VariableOp to store the results. We need to keep track of the VariableOp
38/// created as we need to insert stores into them when lowering Yield. Those
39/// StoreOp cannot be created earlier as they may use a different type than
40/// yield operands.
41ScfToSPIRVContext::ScfToSPIRVContext() {
42 impl = std::make_unique<::ScfToSPIRVContextImpl>();
43}
44
45ScfToSPIRVContext::~ScfToSPIRVContext() = default;
46
47namespace {
48
49//===----------------------------------------------------------------------===//
50// Helper Functions
51//===----------------------------------------------------------------------===//
52
53/// Replaces SCF op outputs with SPIR-V variable loads.
54/// We create VariableOp to handle the results value of the control flow region.
55/// spirv.mlir.loop/spirv.mlir.selection currently don't yield value. Right
56/// after the loop we load the value from the allocation and use it as the SCF
57/// op result.
58template <typename ScfOp, typename OpTy>
59void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
60 ConversionPatternRewriter &rewriter,
61 ScfToSPIRVContextImpl *scfToSPIRVContext,
62 ArrayRef<Type> returnTypes) {
63
64 Location loc = scfOp.getLoc();
65 auto &allocas = scfToSPIRVContext->outputVars[newOp];
66 // Clearing the allocas is necessary in case a dialect conversion path failed
67 // previously, and this is the second attempt of this conversion.
68 allocas.clear();
69 SmallVector<Value, 8> resultValue;
70 for (Type convertedType : returnTypes) {
71 auto pointerType =
72 spirv::PointerType::get(pointeeType: convertedType, storageClass: spirv::StorageClass::Function);
73 rewriter.setInsertionPoint(newOp);
74 auto alloc = rewriter.create<spirv::VariableOp>(
75 location: loc, args&: pointerType, args: spirv::StorageClass::Function,
76 /*initializer=*/args: nullptr);
77 allocas.push_back(alloc);
78 rewriter.setInsertionPointAfter(newOp);
79 Value loadResult = rewriter.create<spirv::LoadOp>(location: loc, args&: alloc);
80 resultValue.push_back(Elt: loadResult);
81 }
82 rewriter.replaceOp(scfOp, resultValue);
83}
84
85Region::iterator getBlockIt(Region &region, unsigned index) {
86 return std::next(x: region.begin(), n: index);
87}
88
89//===----------------------------------------------------------------------===//
90// Conversion Patterns
91//===----------------------------------------------------------------------===//
92
93/// Common class for all vector to GPU patterns.
94template <typename OpTy>
95class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
96public:
97 SCFToSPIRVPattern(MLIRContext *context, const SPIRVTypeConverter &converter,
98 ScfToSPIRVContextImpl *scfToSPIRVContext)
99 : OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
100 scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
101
102protected:
103 ScfToSPIRVContextImpl *scfToSPIRVContext;
104 // FIXME: We explicitly keep a reference of the type converter here instead of
105 // passing it to OpConversionPattern during construction. This effectively
106 // bypasses the conversion framework's automation on type conversion. This is
107 // needed right now because the conversion framework will unconditionally
108 // legalize all types used by SCF ops upon discovering them, for example, the
109 // types of loop carried values. We use SPIR-V variables for those loop
110 // carried values. Depending on the available capabilities, the SPIR-V
111 // variable can be different, for example, cooperative matrix or normal
112 // variable. We'd like to detach the conversion of the loop carried values
113 // from the SCF ops (which is mainly a region). So we need to "mark" types
114 // used by SCF ops as legal, if to use the conversion framework for type
115 // conversion. There isn't a straightforward way to do that yet, as when
116 // converting types, ops aren't taken into consideration. Therefore, we just
117 // bypass the framework's type conversion for now.
118 const SPIRVTypeConverter &typeConverter;
119};
120
121//===----------------------------------------------------------------------===//
122// scf::ForOp
123//===----------------------------------------------------------------------===//
124
125/// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
126struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
127 using SCFToSPIRVPattern::SCFToSPIRVPattern;
128
129 LogicalResult
130 matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
131 ConversionPatternRewriter &rewriter) const override {
132 // scf::ForOp can be lowered to the structured control flow represented by
133 // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
134 // latch and the merge block the exit block. The resulting spirv::LoopOp has
135 // a single back edge from the continue to header block, and a single exit
136 // from header to merge.
137 auto loc = forOp.getLoc();
138 auto loopOp = rewriter.create<spirv::LoopOp>(location: loc, args: spirv::LoopControl::None);
139 loopOp.addEntryAndMergeBlock(builder&: rewriter);
140
141 OpBuilder::InsertionGuard guard(rewriter);
142 // Create the block for the header.
143 Block *header = rewriter.createBlock(parent: &loopOp.getBody(),
144 insertPt: getBlockIt(region&: loopOp.getBody(), index: 1));
145 rewriter.setInsertionPointAfter(loopOp);
146
147 // Create the new induction variable to use.
148 Value adapLowerBound = adaptor.getLowerBound();
149 BlockArgument newIndVar =
150 header->addArgument(type: adapLowerBound.getType(), loc: adapLowerBound.getLoc());
151 for (Value arg : adaptor.getInitArgs())
152 header->addArgument(type: arg.getType(), loc: arg.getLoc());
153 Block *body = forOp.getBody();
154
155 // Apply signature conversion to the body of the forOp. It has a single
156 // block, with argument which is the induction variable. That has to be
157 // replaced with the new induction variable.
158 TypeConverter::SignatureConversion signatureConverter(
159 body->getNumArguments());
160 signatureConverter.remapInput(origInputNo: 0, replacements: newIndVar);
161 for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
162 signatureConverter.remapInput(origInputNo: i, replacements: header->getArgument(i));
163 body = rewriter.applySignatureConversion(block: &forOp.getRegion().front(),
164 conversion&: signatureConverter);
165
166 // Move the blocks from the forOp into the loopOp. This is the body of the
167 // loopOp.
168 rewriter.inlineRegionBefore(region&: forOp->getRegion(index: 0), parent&: loopOp.getBody(),
169 before: getBlockIt(region&: loopOp.getBody(), index: 2));
170
171 SmallVector<Value, 8> args(1, adaptor.getLowerBound());
172 args.append(in_start: adaptor.getInitArgs().begin(), in_end: adaptor.getInitArgs().end());
173 // Branch into it from the entry.
174 rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
175 rewriter.create<spirv::BranchOp>(location: loc, args&: header, args);
176
177 // Generate the rest of the loop header.
178 rewriter.setInsertionPointToEnd(header);
179 auto *mergeBlock = loopOp.getMergeBlock();
180 auto cmpOp = rewriter.create<spirv::SLessThanOp>(
181 location: loc, args: rewriter.getI1Type(), args&: newIndVar, args: adaptor.getUpperBound());
182
183 rewriter.create<spirv::BranchConditionalOp>(
184 location: loc, args&: cmpOp, args&: body, args: ArrayRef<Value>(), args&: mergeBlock, args: ArrayRef<Value>());
185
186 // Generate instructions to increment the step of the induction variable and
187 // branch to the header.
188 Block *continueBlock = loopOp.getContinueBlock();
189 rewriter.setInsertionPointToEnd(continueBlock);
190
191 // Add the step to the induction variable and branch to the header.
192 Value updatedIndVar = rewriter.create<spirv::IAddOp>(
193 location: loc, args: newIndVar.getType(), args&: newIndVar, args: adaptor.getStep());
194 rewriter.create<spirv::BranchOp>(location: loc, args&: header, args&: updatedIndVar);
195
196 // Infer the return types from the init operands. Vector type may get
197 // converted to CooperativeMatrix or to Vector type, to avoid having complex
198 // extra logic to figure out the right type we just infer it from the Init
199 // operands.
200 SmallVector<Type, 8> initTypes;
201 for (auto arg : adaptor.getInitArgs())
202 initTypes.push_back(Elt: arg.getType());
203 replaceSCFOutputValue(scfOp: forOp, newOp: loopOp, rewriter, scfToSPIRVContext,
204 returnTypes: initTypes);
205 return success();
206 }
207};
208
209//===----------------------------------------------------------------------===//
210// scf::IfOp
211//===----------------------------------------------------------------------===//
212
213/// Pattern to convert a scf::IfOp within kernel functions into
214/// spirv::SelectionOp.
215struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
216 using SCFToSPIRVPattern::SCFToSPIRVPattern;
217
218 LogicalResult
219 matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
220 ConversionPatternRewriter &rewriter) const override {
221 // When lowering `scf::IfOp` we explicitly create a selection header block
222 // before the control flow diverges and a merge block where control flow
223 // subsequently converges.
224 auto loc = ifOp.getLoc();
225
226 // Compute return types.
227 SmallVector<Type, 8> returnTypes;
228 for (auto result : ifOp.getResults()) {
229 auto convertedType = typeConverter.convertType(t: result.getType());
230 if (!convertedType)
231 return rewriter.notifyMatchFailure(
232 arg&: loc,
233 msg: llvm::formatv(Fmt: "failed to convert type '{0}'", Vals: result.getType()));
234
235 returnTypes.push_back(Elt: convertedType);
236 }
237
238 // Create `spirv.selection` operation, selection header block and merge
239 // block.
240 auto selectionOp =
241 rewriter.create<spirv::SelectionOp>(location: loc, args: spirv::SelectionControl::None);
242 auto *mergeBlock = rewriter.createBlock(parent: &selectionOp.getBody(),
243 insertPt: selectionOp.getBody().end());
244 rewriter.create<spirv::MergeOp>(location: loc);
245
246 OpBuilder::InsertionGuard guard(rewriter);
247 auto *selectionHeaderBlock =
248 rewriter.createBlock(insertBefore: &selectionOp.getBody().front());
249
250 // Inline `then` region before the merge block and branch to it.
251 auto &thenRegion = ifOp.getThenRegion();
252 auto *thenBlock = &thenRegion.front();
253 rewriter.setInsertionPointToEnd(&thenRegion.back());
254 rewriter.create<spirv::BranchOp>(location: loc, args&: mergeBlock);
255 rewriter.inlineRegionBefore(region&: thenRegion, before: mergeBlock);
256
257 auto *elseBlock = mergeBlock;
258 // If `else` region is not empty, inline that region before the merge block
259 // and branch to it.
260 if (!ifOp.getElseRegion().empty()) {
261 auto &elseRegion = ifOp.getElseRegion();
262 elseBlock = &elseRegion.front();
263 rewriter.setInsertionPointToEnd(&elseRegion.back());
264 rewriter.create<spirv::BranchOp>(location: loc, args&: mergeBlock);
265 rewriter.inlineRegionBefore(region&: elseRegion, before: mergeBlock);
266 }
267
268 // Create a `spirv.BranchConditional` operation for selection header block.
269 rewriter.setInsertionPointToEnd(selectionHeaderBlock);
270 rewriter.create<spirv::BranchConditionalOp>(location: loc, args: adaptor.getCondition(),
271 args&: thenBlock, args: ArrayRef<Value>(),
272 args&: elseBlock, args: ArrayRef<Value>());
273
274 replaceSCFOutputValue(scfOp: ifOp, newOp: selectionOp, rewriter, scfToSPIRVContext,
275 returnTypes);
276 return success();
277 }
278};
279
280//===----------------------------------------------------------------------===//
281// scf::YieldOp
282//===----------------------------------------------------------------------===//
283
284struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
285public:
286 using SCFToSPIRVPattern::SCFToSPIRVPattern;
287
288 LogicalResult
289 matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
290 ConversionPatternRewriter &rewriter) const override {
291 ValueRange operands = adaptor.getOperands();
292
293 Operation *parent = terminatorOp->getParentOp();
294
295 // TODO: Implement conversion for the remaining `scf` ops.
296 if (parent->getDialect()->getNamespace() ==
297 scf::SCFDialect::getDialectNamespace() &&
298 !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(Val: parent))
299 return rewriter.notifyMatchFailure(
300 arg&: terminatorOp,
301 msg: llvm::formatv(Fmt: "conversion not supported for parent op: '{0}'",
302 Vals: parent->getName()));
303
304 // If the region return values, store each value into the associated
305 // VariableOp created during lowering of the parent region.
306 if (!operands.empty()) {
307 auto &allocas = scfToSPIRVContext->outputVars[parent];
308 if (allocas.size() != operands.size())
309 return failure();
310
311 auto loc = terminatorOp.getLoc();
312 for (unsigned i = 0, e = operands.size(); i < e; i++)
313 rewriter.create<spirv::StoreOp>(location: loc, args&: allocas[i], args: operands[i]);
314 if (isa<spirv::LoopOp>(Val: parent)) {
315 // For loops we also need to update the branch jumping back to the
316 // header.
317 auto br = cast<spirv::BranchOp>(
318 Val: rewriter.getInsertionBlock()->getTerminator());
319 SmallVector<Value, 8> args(br.getBlockArguments());
320 args.append(in_start: operands.begin(), in_end: operands.end());
321 rewriter.setInsertionPoint(br);
322 rewriter.create<spirv::BranchOp>(location: terminatorOp.getLoc(), args: br.getTarget(),
323 args);
324 rewriter.eraseOp(op: br);
325 }
326 }
327 rewriter.eraseOp(op: terminatorOp);
328 return success();
329 }
330};
331
332//===----------------------------------------------------------------------===//
333// scf::WhileOp
334//===----------------------------------------------------------------------===//
335
336struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
337 using SCFToSPIRVPattern::SCFToSPIRVPattern;
338
339 LogicalResult
340 matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
341 ConversionPatternRewriter &rewriter) const override {
342 auto loc = whileOp.getLoc();
343 auto loopOp = rewriter.create<spirv::LoopOp>(location: loc, args: spirv::LoopControl::None);
344 loopOp.addEntryAndMergeBlock(builder&: rewriter);
345
346 Region &beforeRegion = whileOp.getBefore();
347 Region &afterRegion = whileOp.getAfter();
348
349 if (failed(Result: rewriter.convertRegionTypes(region: &beforeRegion, converter: typeConverter)) ||
350 failed(Result: rewriter.convertRegionTypes(region: &afterRegion, converter: typeConverter)))
351 return rewriter.notifyMatchFailure(arg&: whileOp,
352 msg: "Failed to convert region types");
353
354 OpBuilder::InsertionGuard guard(rewriter);
355
356 Block &entryBlock = *loopOp.getEntryBlock();
357 Block &beforeBlock = beforeRegion.front();
358 Block &afterBlock = afterRegion.front();
359 Block &mergeBlock = *loopOp.getMergeBlock();
360
361 auto cond = cast<scf::ConditionOp>(Val: beforeBlock.getTerminator());
362 SmallVector<Value> condArgs;
363 if (failed(Result: rewriter.getRemappedValues(keys: cond.getArgs(), results&: condArgs)))
364 return failure();
365
366 Value conditionVal = rewriter.getRemappedValue(key: cond.getCondition());
367 if (!conditionVal)
368 return failure();
369
370 auto yield = cast<scf::YieldOp>(Val: afterBlock.getTerminator());
371 SmallVector<Value> yieldArgs;
372 if (failed(Result: rewriter.getRemappedValues(keys: yield.getResults(), results&: yieldArgs)))
373 return failure();
374
375 // Move the while before block as the initial loop header block.
376 rewriter.inlineRegionBefore(region&: beforeRegion, parent&: loopOp.getBody(),
377 before: getBlockIt(region&: loopOp.getBody(), index: 1));
378
379 // Move the while after block as the initial loop body block.
380 rewriter.inlineRegionBefore(region&: afterRegion, parent&: loopOp.getBody(),
381 before: getBlockIt(region&: loopOp.getBody(), index: 2));
382
383 // Jump from the loop entry block to the loop header block.
384 rewriter.setInsertionPointToEnd(&entryBlock);
385 rewriter.create<spirv::BranchOp>(location: loc, args: &beforeBlock, args: adaptor.getInits());
386
387 auto condLoc = cond.getLoc();
388
389 SmallVector<Value> resultValues(condArgs.size());
390
391 // For other SCF ops, the scf.yield op yields the value for the whole SCF
392 // op. So we use the scf.yield op as the anchor to create/load/store SPIR-V
393 // local variables. But for the scf.while op, the scf.yield op yields a
394 // value for the before region, which may not matching the whole op's
395 // result. Instead, the scf.condition op returns values matching the whole
396 // op's results. So we need to create/load/store variables according to
397 // that.
398 for (const auto &it : llvm::enumerate(First&: condArgs)) {
399 auto res = it.value();
400 auto i = it.index();
401 auto pointerType =
402 spirv::PointerType::get(pointeeType: res.getType(), storageClass: spirv::StorageClass::Function);
403
404 // Create local variables before the scf.while op.
405 rewriter.setInsertionPoint(loopOp);
406 auto alloc = rewriter.create<spirv::VariableOp>(
407 location: condLoc, args&: pointerType, args: spirv::StorageClass::Function,
408 /*initializer=*/args: nullptr);
409
410 // Load the final result values after the scf.while op.
411 rewriter.setInsertionPointAfter(loopOp);
412 auto loadResult = rewriter.create<spirv::LoadOp>(location: condLoc, args&: alloc);
413 resultValues[i] = loadResult;
414
415 // Store the current iteration's result value.
416 rewriter.setInsertionPointToEnd(&beforeBlock);
417 rewriter.create<spirv::StoreOp>(location: condLoc, args&: alloc, args&: res);
418 }
419
420 rewriter.setInsertionPointToEnd(&beforeBlock);
421 rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
422 op: cond, args&: conditionVal, args: &afterBlock, args&: condArgs, args: &mergeBlock, args: ValueRange());
423
424 // Convert the scf.yield op to a branch back to the header block.
425 rewriter.setInsertionPointToEnd(&afterBlock);
426 rewriter.replaceOpWithNewOp<spirv::BranchOp>(op: yield, args: &beforeBlock,
427 args&: yieldArgs);
428
429 rewriter.replaceOp(op: whileOp, newValues: resultValues);
430 return success();
431 }
432};
433} // namespace
434
435//===----------------------------------------------------------------------===//
436// Public API
437//===----------------------------------------------------------------------===//
438
439void mlir::populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
440 ScfToSPIRVContext &scfToSPIRVContext,
441 RewritePatternSet &patterns) {
442 patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
443 WhileOpConversion>(arg: patterns.getContext(), args: typeConverter,
444 args: scfToSPIRVContext.getImpl());
445}
446

// Source: mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp