1//===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert SCF dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
14#include "mlir/Dialect/SCF/IR/SCF.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18#include "mlir/IR/BuiltinOps.h"
19#include "mlir/Transforms/DialectConversion.h"
20#include "llvm/Support/FormatVariadic.h"
21
22using namespace mlir;
23
24//===----------------------------------------------------------------------===//
25// Context
26//===----------------------------------------------------------------------===//
27
28namespace mlir {
29struct ScfToSPIRVContextImpl {
30 // Map between the spirv region control flow operation (spirv.mlir.loop or
31 // spirv.mlir.selection) to the VariableOp created to store the region
32 // results. The order of the VariableOp matches the order of the results.
33 DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
34};
35} // namespace mlir
36
37/// We use ScfToSPIRVContext to store information about the lowering of the scf
38/// region that need to be used later on. When we lower scf.for/scf.if we create
39/// VariableOp to store the results. We need to keep track of the VariableOp
40/// created as we need to insert stores into them when lowering Yield. Those
41/// StoreOp cannot be created earlier as they may use a different type than
42/// yield operands.
43ScfToSPIRVContext::ScfToSPIRVContext() {
44 impl = std::make_unique<::ScfToSPIRVContextImpl>();
45}
46
47ScfToSPIRVContext::~ScfToSPIRVContext() = default;
48
49namespace {
50
51//===----------------------------------------------------------------------===//
52// Helper Functions
53//===----------------------------------------------------------------------===//
54
55/// Replaces SCF op outputs with SPIR-V variable loads.
56/// We create VariableOp to handle the results value of the control flow region.
57/// spirv.mlir.loop/spirv.mlir.selection currently don't yield value. Right
58/// after the loop we load the value from the allocation and use it as the SCF
59/// op result.
60template <typename ScfOp, typename OpTy>
61void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
62 ConversionPatternRewriter &rewriter,
63 ScfToSPIRVContextImpl *scfToSPIRVContext,
64 ArrayRef<Type> returnTypes) {
65
66 Location loc = scfOp.getLoc();
67 auto &allocas = scfToSPIRVContext->outputVars[newOp];
68 // Clearing the allocas is necessary in case a dialect conversion path failed
69 // previously, and this is the second attempt of this conversion.
70 allocas.clear();
71 SmallVector<Value, 8> resultValue;
72 for (Type convertedType : returnTypes) {
73 auto pointerType =
74 spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
75 rewriter.setInsertionPoint(newOp);
76 auto alloc = rewriter.create<spirv::VariableOp>(
77 loc, pointerType, spirv::StorageClass::Function,
78 /*initializer=*/nullptr);
79 allocas.push_back(alloc);
80 rewriter.setInsertionPointAfter(newOp);
81 Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
82 resultValue.push_back(Elt: loadResult);
83 }
84 rewriter.replaceOp(scfOp, resultValue);
85}
86
87Region::iterator getBlockIt(Region &region, unsigned index) {
88 return std::next(x: region.begin(), n: index);
89}
90
91//===----------------------------------------------------------------------===//
92// Conversion Patterns
93//===----------------------------------------------------------------------===//
94
95/// Common class for all vector to GPU patterns.
96template <typename OpTy>
97class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
98public:
99 SCFToSPIRVPattern(MLIRContext *context, SPIRVTypeConverter &converter,
100 ScfToSPIRVContextImpl *scfToSPIRVContext)
101 : OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
102 scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
103
104protected:
105 ScfToSPIRVContextImpl *scfToSPIRVContext;
106 // FIXME: We explicitly keep a reference of the type converter here instead of
107 // passing it to OpConversionPattern during construction. This effectively
108 // bypasses the conversion framework's automation on type conversion. This is
109 // needed right now because the conversion framework will unconditionally
110 // legalize all types used by SCF ops upon discovering them, for example, the
111 // types of loop carried values. We use SPIR-V variables for those loop
112 // carried values. Depending on the available capabilities, the SPIR-V
113 // variable can be different, for example, cooperative matrix or normal
114 // variable. We'd like to detach the conversion of the loop carried values
115 // from the SCF ops (which is mainly a region). So we need to "mark" types
116 // used by SCF ops as legal, if to use the conversion framework for type
117 // conversion. There isn't a straightforward way to do that yet, as when
118 // converting types, ops aren't taken into consideration. Therefore, we just
119 // bypass the framework's type conversion for now.
120 SPIRVTypeConverter &typeConverter;
121};
122
123//===----------------------------------------------------------------------===//
124// scf::ForOp
125//===----------------------------------------------------------------------===//
126
127/// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
128struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
129 using SCFToSPIRVPattern::SCFToSPIRVPattern;
130
131 LogicalResult
132 matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
133 ConversionPatternRewriter &rewriter) const override {
134 // scf::ForOp can be lowered to the structured control flow represented by
135 // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
136 // latch and the merge block the exit block. The resulting spirv::LoopOp has
137 // a single back edge from the continue to header block, and a single exit
138 // from header to merge.
139 auto loc = forOp.getLoc();
140 auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
141 loopOp.addEntryAndMergeBlock(rewriter);
142
143 OpBuilder::InsertionGuard guard(rewriter);
144 // Create the block for the header.
145 Block *header = rewriter.createBlock(&loopOp.getBody(),
146 getBlockIt(loopOp.getBody(), 1));
147 rewriter.setInsertionPointAfter(loopOp);
148
149 // Create the new induction variable to use.
150 Value adapLowerBound = adaptor.getLowerBound();
151 BlockArgument newIndVar =
152 header->addArgument(type: adapLowerBound.getType(), loc: adapLowerBound.getLoc());
153 for (Value arg : adaptor.getInitArgs())
154 header->addArgument(arg.getType(), arg.getLoc());
155 Block *body = forOp.getBody();
156
157 // Apply signature conversion to the body of the forOp. It has a single
158 // block, with argument which is the induction variable. That has to be
159 // replaced with the new induction variable.
160 TypeConverter::SignatureConversion signatureConverter(
161 body->getNumArguments());
162 signatureConverter.remapInput(origInputNo: 0, replacement: newIndVar);
163 for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
164 signatureConverter.remapInput(origInputNo: i, replacement: header->getArgument(i));
165 body = rewriter.applySignatureConversion(region: &forOp.getRegion(),
166 conversion&: signatureConverter);
167
168 // Move the blocks from the forOp into the loopOp. This is the body of the
169 // loopOp.
170 rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
171 getBlockIt(loopOp.getBody(), 2));
172
173 SmallVector<Value, 8> args(1, adaptor.getLowerBound());
174 args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
175 // Branch into it from the entry.
176 rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
177 rewriter.create<spirv::BranchOp>(loc, header, args);
178
179 // Generate the rest of the loop header.
180 rewriter.setInsertionPointToEnd(header);
181 auto *mergeBlock = loopOp.getMergeBlock();
182 auto cmpOp = rewriter.create<spirv::SLessThanOp>(
183 loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound());
184
185 rewriter.create<spirv::BranchConditionalOp>(
186 loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
187
188 // Generate instructions to increment the step of the induction variable and
189 // branch to the header.
190 Block *continueBlock = loopOp.getContinueBlock();
191 rewriter.setInsertionPointToEnd(continueBlock);
192
193 // Add the step to the induction variable and branch to the header.
194 Value updatedIndVar = rewriter.create<spirv::IAddOp>(
195 loc, newIndVar.getType(), newIndVar, adaptor.getStep());
196 rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
197
198 // Infer the return types from the init operands. Vector type may get
199 // converted to CooperativeMatrix or to Vector type, to avoid having complex
200 // extra logic to figure out the right type we just infer it from the Init
201 // operands.
202 SmallVector<Type, 8> initTypes;
203 for (auto arg : adaptor.getInitArgs())
204 initTypes.push_back(arg.getType());
205 replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext,
206 initTypes);
207 return success();
208 }
209};
210
211//===----------------------------------------------------------------------===//
212// scf::IfOp
213//===----------------------------------------------------------------------===//
214
215/// Pattern to convert a scf::IfOp within kernel functions into
216/// spirv::SelectionOp.
217struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
218 using SCFToSPIRVPattern::SCFToSPIRVPattern;
219
220 LogicalResult
221 matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
222 ConversionPatternRewriter &rewriter) const override {
223 // When lowering `scf::IfOp` we explicitly create a selection header block
224 // before the control flow diverges and a merge block where control flow
225 // subsequently converges.
226 auto loc = ifOp.getLoc();
227
228 // Create `spirv.selection` operation, selection header block and merge
229 // block.
230 auto selectionOp =
231 rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
232 auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
233 selectionOp.getBody().end());
234 rewriter.create<spirv::MergeOp>(loc);
235
236 OpBuilder::InsertionGuard guard(rewriter);
237 auto *selectionHeaderBlock =
238 rewriter.createBlock(&selectionOp.getBody().front());
239
240 // Inline `then` region before the merge block and branch to it.
241 auto &thenRegion = ifOp.getThenRegion();
242 auto *thenBlock = &thenRegion.front();
243 rewriter.setInsertionPointToEnd(&thenRegion.back());
244 rewriter.create<spirv::BranchOp>(loc, mergeBlock);
245 rewriter.inlineRegionBefore(thenRegion, mergeBlock);
246
247 auto *elseBlock = mergeBlock;
248 // If `else` region is not empty, inline that region before the merge block
249 // and branch to it.
250 if (!ifOp.getElseRegion().empty()) {
251 auto &elseRegion = ifOp.getElseRegion();
252 elseBlock = &elseRegion.front();
253 rewriter.setInsertionPointToEnd(&elseRegion.back());
254 rewriter.create<spirv::BranchOp>(loc, mergeBlock);
255 rewriter.inlineRegionBefore(elseRegion, mergeBlock);
256 }
257
258 // Create a `spirv.BranchConditional` operation for selection header block.
259 rewriter.setInsertionPointToEnd(selectionHeaderBlock);
260 rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(),
261 thenBlock, ArrayRef<Value>(),
262 elseBlock, ArrayRef<Value>());
263
264 SmallVector<Type, 8> returnTypes;
265 for (auto result : ifOp.getResults()) {
266 auto convertedType = typeConverter.convertType(result.getType());
267 if (!convertedType)
268 return rewriter.notifyMatchFailure(
269 loc,
270 llvm::formatv("failed to convert type '{0}'", result.getType()));
271
272 returnTypes.push_back(convertedType);
273 }
274 replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
275 returnTypes);
276 return success();
277 }
278};
279
280//===----------------------------------------------------------------------===//
281// scf::YieldOp
282//===----------------------------------------------------------------------===//
283
284struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
285public:
286 using SCFToSPIRVPattern::SCFToSPIRVPattern;
287
288 LogicalResult
289 matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
290 ConversionPatternRewriter &rewriter) const override {
291 ValueRange operands = adaptor.getOperands();
292
293 Operation *parent = terminatorOp->getParentOp();
294
295 // TODO: Implement conversion for the remaining `scf` ops.
296 if (parent->getDialect()->getNamespace() ==
297 scf::SCFDialect::getDialectNamespace() &&
298 !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
299 return rewriter.notifyMatchFailure(
300 terminatorOp,
301 llvm::formatv(Fmt: "conversion not supported for parent op: '{0}'",
302 Vals: parent->getName()));
303
304 // If the region return values, store each value into the associated
305 // VariableOp created during lowering of the parent region.
306 if (!operands.empty()) {
307 auto &allocas = scfToSPIRVContext->outputVars[parent];
308 if (allocas.size() != operands.size())
309 return failure();
310
311 auto loc = terminatorOp.getLoc();
312 for (unsigned i = 0, e = operands.size(); i < e; i++)
313 rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
314 if (isa<spirv::LoopOp>(parent)) {
315 // For loops we also need to update the branch jumping back to the
316 // header.
317 auto br = cast<spirv::BranchOp>(
318 rewriter.getInsertionBlock()->getTerminator());
319 SmallVector<Value, 8> args(br.getBlockArguments());
320 args.append(in_start: operands.begin(), in_end: operands.end());
321 rewriter.setInsertionPoint(br);
322 rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
323 args);
324 rewriter.eraseOp(op: br);
325 }
326 }
327 rewriter.eraseOp(op: terminatorOp);
328 return success();
329 }
330};
331
332//===----------------------------------------------------------------------===//
333// scf::WhileOp
334//===----------------------------------------------------------------------===//
335
336struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
337 using SCFToSPIRVPattern::SCFToSPIRVPattern;
338
339 LogicalResult
340 matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
341 ConversionPatternRewriter &rewriter) const override {
342 auto loc = whileOp.getLoc();
343 auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
344 loopOp.addEntryAndMergeBlock(rewriter);
345
346 Region &beforeRegion = whileOp.getBefore();
347 Region &afterRegion = whileOp.getAfter();
348
349 if (failed(rewriter.convertRegionTypes(&beforeRegion, typeConverter)) ||
350 failed(rewriter.convertRegionTypes(&afterRegion, typeConverter)))
351 return rewriter.notifyMatchFailure(whileOp,
352 "Failed to convert region types");
353
354 OpBuilder::InsertionGuard guard(rewriter);
355
356 Block &entryBlock = *loopOp.getEntryBlock();
357 Block &beforeBlock = beforeRegion.front();
358 Block &afterBlock = afterRegion.front();
359 Block &mergeBlock = *loopOp.getMergeBlock();
360
361 auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator());
362 SmallVector<Value> condArgs;
363 if (failed(rewriter.getRemappedValues(keys: cond.getArgs(), results&: condArgs)))
364 return failure();
365
366 Value conditionVal = rewriter.getRemappedValue(key: cond.getCondition());
367 if (!conditionVal)
368 return failure();
369
370 auto yield = cast<scf::YieldOp>(afterBlock.getTerminator());
371 SmallVector<Value> yieldArgs;
372 if (failed(rewriter.getRemappedValues(keys: yield.getResults(), results&: yieldArgs)))
373 return failure();
374
375 // Move the while before block as the initial loop header block.
376 rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
377 getBlockIt(loopOp.getBody(), 1));
378
379 // Move the while after block as the initial loop body block.
380 rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
381 getBlockIt(loopOp.getBody(), 2));
382
383 // Jump from the loop entry block to the loop header block.
384 rewriter.setInsertionPointToEnd(&entryBlock);
385 rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits());
386
387 auto condLoc = cond.getLoc();
388
389 SmallVector<Value> resultValues(condArgs.size());
390
391 // For other SCF ops, the scf.yield op yields the value for the whole SCF
392 // op. So we use the scf.yield op as the anchor to create/load/store SPIR-V
393 // local variables. But for the scf.while op, the scf.yield op yields a
394 // value for the before region, which may not matching the whole op's
395 // result. Instead, the scf.condition op returns values matching the whole
396 // op's results. So we need to create/load/store variables according to
397 // that.
398 for (const auto &it : llvm::enumerate(First&: condArgs)) {
399 auto res = it.value();
400 auto i = it.index();
401 auto pointerType =
402 spirv::PointerType::get(res.getType(), spirv::StorageClass::Function);
403
404 // Create local variables before the scf.while op.
405 rewriter.setInsertionPoint(loopOp);
406 auto alloc = rewriter.create<spirv::VariableOp>(
407 condLoc, pointerType, spirv::StorageClass::Function,
408 /*initializer=*/nullptr);
409
410 // Load the final result values after the scf.while op.
411 rewriter.setInsertionPointAfter(loopOp);
412 auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc);
413 resultValues[i] = loadResult;
414
415 // Store the current iteration's result value.
416 rewriter.setInsertionPointToEnd(&beforeBlock);
417 rewriter.create<spirv::StoreOp>(condLoc, alloc, res);
418 }
419
420 rewriter.setInsertionPointToEnd(&beforeBlock);
421 rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
422 cond, conditionVal, &afterBlock, condArgs, &mergeBlock, std::nullopt);
423
424 // Convert the scf.yield op to a branch back to the header block.
425 rewriter.setInsertionPointToEnd(&afterBlock);
426 rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock,
427 yieldArgs);
428
429 rewriter.replaceOp(whileOp, resultValues);
430 return success();
431 }
432};
433} // namespace
434
435//===----------------------------------------------------------------------===//
436// Public API
437//===----------------------------------------------------------------------===//
438
439void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
440 ScfToSPIRVContext &scfToSPIRVContext,
441 RewritePatternSet &patterns) {
442 patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
443 WhileOpConversion>(arg: patterns.getContext(), args&: typeConverter,
444 args: scfToSPIRVContext.getImpl());
445}
446

source code of mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp