1//===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert SCF dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
14#include "mlir/Dialect/SCF/IR/SCF.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18#include "mlir/IR/BuiltinOps.h"
19#include "mlir/Transforms/DialectConversion.h"
20#include "llvm/Support/FormatVariadic.h"
21
22using namespace mlir;
23
24//===----------------------------------------------------------------------===//
25// Context
26//===----------------------------------------------------------------------===//
27
28namespace mlir {
29struct ScfToSPIRVContextImpl {
30 // Map between the spirv region control flow operation (spirv.mlir.loop or
31 // spirv.mlir.selection) to the VariableOp created to store the region
32 // results. The order of the VariableOp matches the order of the results.
33 DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
34};
35} // namespace mlir
36
37/// We use ScfToSPIRVContext to store information about the lowering of the scf
38/// region that need to be used later on. When we lower scf.for/scf.if we create
39/// VariableOp to store the results. We need to keep track of the VariableOp
40/// created as we need to insert stores into them when lowering Yield. Those
41/// StoreOp cannot be created earlier as they may use a different type than
42/// yield operands.
43ScfToSPIRVContext::ScfToSPIRVContext() {
44 impl = std::make_unique<::ScfToSPIRVContextImpl>();
45}
46
47ScfToSPIRVContext::~ScfToSPIRVContext() = default;
48
49namespace {
50
51//===----------------------------------------------------------------------===//
52// Helper Functions
53//===----------------------------------------------------------------------===//
54
55/// Replaces SCF op outputs with SPIR-V variable loads.
56/// We create VariableOp to handle the results value of the control flow region.
57/// spirv.mlir.loop/spirv.mlir.selection currently don't yield value. Right
58/// after the loop we load the value from the allocation and use it as the SCF
59/// op result.
60template <typename ScfOp, typename OpTy>
61void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
62 ConversionPatternRewriter &rewriter,
63 ScfToSPIRVContextImpl *scfToSPIRVContext,
64 ArrayRef<Type> returnTypes) {
65
66 Location loc = scfOp.getLoc();
67 auto &allocas = scfToSPIRVContext->outputVars[newOp];
68 // Clearing the allocas is necessary in case a dialect conversion path failed
69 // previously, and this is the second attempt of this conversion.
70 allocas.clear();
71 SmallVector<Value, 8> resultValue;
72 for (Type convertedType : returnTypes) {
73 auto pointerType =
74 spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
75 rewriter.setInsertionPoint(newOp);
76 auto alloc = rewriter.create<spirv::VariableOp>(
77 loc, pointerType, spirv::StorageClass::Function,
78 /*initializer=*/nullptr);
79 allocas.push_back(alloc);
80 rewriter.setInsertionPointAfter(newOp);
81 Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
82 resultValue.push_back(Elt: loadResult);
83 }
84 rewriter.replaceOp(scfOp, resultValue);
85}
86
87Region::iterator getBlockIt(Region &region, unsigned index) {
88 return std::next(x: region.begin(), n: index);
89}
90
91//===----------------------------------------------------------------------===//
92// Conversion Patterns
93//===----------------------------------------------------------------------===//
94
95/// Common class for all vector to GPU patterns.
96template <typename OpTy>
97class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
98public:
99 SCFToSPIRVPattern(MLIRContext *context, const SPIRVTypeConverter &converter,
100 ScfToSPIRVContextImpl *scfToSPIRVContext)
101 : OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
102 scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
103
104protected:
105 ScfToSPIRVContextImpl *scfToSPIRVContext;
106 // FIXME: We explicitly keep a reference of the type converter here instead of
107 // passing it to OpConversionPattern during construction. This effectively
108 // bypasses the conversion framework's automation on type conversion. This is
109 // needed right now because the conversion framework will unconditionally
110 // legalize all types used by SCF ops upon discovering them, for example, the
111 // types of loop carried values. We use SPIR-V variables for those loop
112 // carried values. Depending on the available capabilities, the SPIR-V
113 // variable can be different, for example, cooperative matrix or normal
114 // variable. We'd like to detach the conversion of the loop carried values
115 // from the SCF ops (which is mainly a region). So we need to "mark" types
116 // used by SCF ops as legal, if to use the conversion framework for type
117 // conversion. There isn't a straightforward way to do that yet, as when
118 // converting types, ops aren't taken into consideration. Therefore, we just
119 // bypass the framework's type conversion for now.
120 const SPIRVTypeConverter &typeConverter;
121};
122
123//===----------------------------------------------------------------------===//
124// scf::ForOp
125//===----------------------------------------------------------------------===//
126
127/// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
128struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
129 using SCFToSPIRVPattern::SCFToSPIRVPattern;
130
131 LogicalResult
132 matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
133 ConversionPatternRewriter &rewriter) const override {
134 // scf::ForOp can be lowered to the structured control flow represented by
135 // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
136 // latch and the merge block the exit block. The resulting spirv::LoopOp has
137 // a single back edge from the continue to header block, and a single exit
138 // from header to merge.
139 auto loc = forOp.getLoc();
140 auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
141 loopOp.addEntryAndMergeBlock(rewriter);
142
143 OpBuilder::InsertionGuard guard(rewriter);
144 // Create the block for the header.
145 Block *header = rewriter.createBlock(&loopOp.getBody(),
146 getBlockIt(loopOp.getBody(), 1));
147 rewriter.setInsertionPointAfter(loopOp);
148
149 // Create the new induction variable to use.
150 Value adapLowerBound = adaptor.getLowerBound();
151 BlockArgument newIndVar =
152 header->addArgument(type: adapLowerBound.getType(), loc: adapLowerBound.getLoc());
153 for (Value arg : adaptor.getInitArgs())
154 header->addArgument(arg.getType(), arg.getLoc());
155 Block *body = forOp.getBody();
156
157 // Apply signature conversion to the body of the forOp. It has a single
158 // block, with argument which is the induction variable. That has to be
159 // replaced with the new induction variable.
160 TypeConverter::SignatureConversion signatureConverter(
161 body->getNumArguments());
162 signatureConverter.remapInput(origInputNo: 0, replacements: newIndVar);
163 for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
164 signatureConverter.remapInput(origInputNo: i, replacements: header->getArgument(i));
165 body = rewriter.applySignatureConversion(block: &forOp.getRegion().front(),
166 conversion&: signatureConverter);
167
168 // Move the blocks from the forOp into the loopOp. This is the body of the
169 // loopOp.
170 rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
171 getBlockIt(loopOp.getBody(), 2));
172
173 SmallVector<Value, 8> args(1, adaptor.getLowerBound());
174 args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
175 // Branch into it from the entry.
176 rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
177 rewriter.create<spirv::BranchOp>(loc, header, args);
178
179 // Generate the rest of the loop header.
180 rewriter.setInsertionPointToEnd(header);
181 auto *mergeBlock = loopOp.getMergeBlock();
182 auto cmpOp = rewriter.create<spirv::SLessThanOp>(
183 loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound());
184
185 rewriter.create<spirv::BranchConditionalOp>(
186 loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
187
188 // Generate instructions to increment the step of the induction variable and
189 // branch to the header.
190 Block *continueBlock = loopOp.getContinueBlock();
191 rewriter.setInsertionPointToEnd(continueBlock);
192
193 // Add the step to the induction variable and branch to the header.
194 Value updatedIndVar = rewriter.create<spirv::IAddOp>(
195 loc, newIndVar.getType(), newIndVar, adaptor.getStep());
196 rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
197
198 // Infer the return types from the init operands. Vector type may get
199 // converted to CooperativeMatrix or to Vector type, to avoid having complex
200 // extra logic to figure out the right type we just infer it from the Init
201 // operands.
202 SmallVector<Type, 8> initTypes;
203 for (auto arg : adaptor.getInitArgs())
204 initTypes.push_back(arg.getType());
205 replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext,
206 initTypes);
207 return success();
208 }
209};
210
211//===----------------------------------------------------------------------===//
212// scf::IfOp
213//===----------------------------------------------------------------------===//
214
215/// Pattern to convert a scf::IfOp within kernel functions into
216/// spirv::SelectionOp.
217struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
218 using SCFToSPIRVPattern::SCFToSPIRVPattern;
219
220 LogicalResult
221 matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
222 ConversionPatternRewriter &rewriter) const override {
223 // When lowering `scf::IfOp` we explicitly create a selection header block
224 // before the control flow diverges and a merge block where control flow
225 // subsequently converges.
226 auto loc = ifOp.getLoc();
227
228 // Compute return types.
229 SmallVector<Type, 8> returnTypes;
230 for (auto result : ifOp.getResults()) {
231 auto convertedType = typeConverter.convertType(result.getType());
232 if (!convertedType)
233 return rewriter.notifyMatchFailure(
234 loc,
235 llvm::formatv("failed to convert type '{0}'", result.getType()));
236
237 returnTypes.push_back(convertedType);
238 }
239
240 // Create `spirv.selection` operation, selection header block and merge
241 // block.
242 auto selectionOp =
243 rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
244 auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
245 selectionOp.getBody().end());
246 rewriter.create<spirv::MergeOp>(loc);
247
248 OpBuilder::InsertionGuard guard(rewriter);
249 auto *selectionHeaderBlock =
250 rewriter.createBlock(&selectionOp.getBody().front());
251
252 // Inline `then` region before the merge block and branch to it.
253 auto &thenRegion = ifOp.getThenRegion();
254 auto *thenBlock = &thenRegion.front();
255 rewriter.setInsertionPointToEnd(&thenRegion.back());
256 rewriter.create<spirv::BranchOp>(loc, mergeBlock);
257 rewriter.inlineRegionBefore(thenRegion, mergeBlock);
258
259 auto *elseBlock = mergeBlock;
260 // If `else` region is not empty, inline that region before the merge block
261 // and branch to it.
262 if (!ifOp.getElseRegion().empty()) {
263 auto &elseRegion = ifOp.getElseRegion();
264 elseBlock = &elseRegion.front();
265 rewriter.setInsertionPointToEnd(&elseRegion.back());
266 rewriter.create<spirv::BranchOp>(loc, mergeBlock);
267 rewriter.inlineRegionBefore(elseRegion, mergeBlock);
268 }
269
270 // Create a `spirv.BranchConditional` operation for selection header block.
271 rewriter.setInsertionPointToEnd(selectionHeaderBlock);
272 rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(),
273 thenBlock, ArrayRef<Value>(),
274 elseBlock, ArrayRef<Value>());
275
276 replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
277 returnTypes);
278 return success();
279 }
280};
281
282//===----------------------------------------------------------------------===//
283// scf::YieldOp
284//===----------------------------------------------------------------------===//
285
286struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
287public:
288 using SCFToSPIRVPattern::SCFToSPIRVPattern;
289
290 LogicalResult
291 matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
292 ConversionPatternRewriter &rewriter) const override {
293 ValueRange operands = adaptor.getOperands();
294
295 Operation *parent = terminatorOp->getParentOp();
296
297 // TODO: Implement conversion for the remaining `scf` ops.
298 if (parent->getDialect()->getNamespace() ==
299 scf::SCFDialect::getDialectNamespace() &&
300 !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
301 return rewriter.notifyMatchFailure(
302 terminatorOp,
303 llvm::formatv(Fmt: "conversion not supported for parent op: '{0}'",
304 Vals: parent->getName()));
305
306 // If the region return values, store each value into the associated
307 // VariableOp created during lowering of the parent region.
308 if (!operands.empty()) {
309 auto &allocas = scfToSPIRVContext->outputVars[parent];
310 if (allocas.size() != operands.size())
311 return failure();
312
313 auto loc = terminatorOp.getLoc();
314 for (unsigned i = 0, e = operands.size(); i < e; i++)
315 rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
316 if (isa<spirv::LoopOp>(parent)) {
317 // For loops we also need to update the branch jumping back to the
318 // header.
319 auto br = cast<spirv::BranchOp>(
320 rewriter.getInsertionBlock()->getTerminator());
321 SmallVector<Value, 8> args(br.getBlockArguments());
322 args.append(in_start: operands.begin(), in_end: operands.end());
323 rewriter.setInsertionPoint(br);
324 rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
325 args);
326 rewriter.eraseOp(op: br);
327 }
328 }
329 rewriter.eraseOp(op: terminatorOp);
330 return success();
331 }
332};
333
334//===----------------------------------------------------------------------===//
335// scf::WhileOp
336//===----------------------------------------------------------------------===//
337
338struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
339 using SCFToSPIRVPattern::SCFToSPIRVPattern;
340
341 LogicalResult
342 matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
343 ConversionPatternRewriter &rewriter) const override {
344 auto loc = whileOp.getLoc();
345 auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
346 loopOp.addEntryAndMergeBlock(rewriter);
347
348 Region &beforeRegion = whileOp.getBefore();
349 Region &afterRegion = whileOp.getAfter();
350
351 if (failed(rewriter.convertRegionTypes(&beforeRegion, typeConverter)) ||
352 failed(rewriter.convertRegionTypes(&afterRegion, typeConverter)))
353 return rewriter.notifyMatchFailure(whileOp,
354 "Failed to convert region types");
355
356 OpBuilder::InsertionGuard guard(rewriter);
357
358 Block &entryBlock = *loopOp.getEntryBlock();
359 Block &beforeBlock = beforeRegion.front();
360 Block &afterBlock = afterRegion.front();
361 Block &mergeBlock = *loopOp.getMergeBlock();
362
363 auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator());
364 SmallVector<Value> condArgs;
365 if (failed(rewriter.getRemappedValues(keys: cond.getArgs(), results&: condArgs)))
366 return failure();
367
368 Value conditionVal = rewriter.getRemappedValue(key: cond.getCondition());
369 if (!conditionVal)
370 return failure();
371
372 auto yield = cast<scf::YieldOp>(afterBlock.getTerminator());
373 SmallVector<Value> yieldArgs;
374 if (failed(rewriter.getRemappedValues(keys: yield.getResults(), results&: yieldArgs)))
375 return failure();
376
377 // Move the while before block as the initial loop header block.
378 rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
379 getBlockIt(loopOp.getBody(), 1));
380
381 // Move the while after block as the initial loop body block.
382 rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
383 getBlockIt(loopOp.getBody(), 2));
384
385 // Jump from the loop entry block to the loop header block.
386 rewriter.setInsertionPointToEnd(&entryBlock);
387 rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits());
388
389 auto condLoc = cond.getLoc();
390
391 SmallVector<Value> resultValues(condArgs.size());
392
393 // For other SCF ops, the scf.yield op yields the value for the whole SCF
394 // op. So we use the scf.yield op as the anchor to create/load/store SPIR-V
395 // local variables. But for the scf.while op, the scf.yield op yields a
396 // value for the before region, which may not matching the whole op's
397 // result. Instead, the scf.condition op returns values matching the whole
398 // op's results. So we need to create/load/store variables according to
399 // that.
400 for (const auto &it : llvm::enumerate(First&: condArgs)) {
401 auto res = it.value();
402 auto i = it.index();
403 auto pointerType =
404 spirv::PointerType::get(res.getType(), spirv::StorageClass::Function);
405
406 // Create local variables before the scf.while op.
407 rewriter.setInsertionPoint(loopOp);
408 auto alloc = rewriter.create<spirv::VariableOp>(
409 condLoc, pointerType, spirv::StorageClass::Function,
410 /*initializer=*/nullptr);
411
412 // Load the final result values after the scf.while op.
413 rewriter.setInsertionPointAfter(loopOp);
414 auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc);
415 resultValues[i] = loadResult;
416
417 // Store the current iteration's result value.
418 rewriter.setInsertionPointToEnd(&beforeBlock);
419 rewriter.create<spirv::StoreOp>(condLoc, alloc, res);
420 }
421
422 rewriter.setInsertionPointToEnd(&beforeBlock);
423 rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
424 cond, conditionVal, &afterBlock, condArgs, &mergeBlock, std::nullopt);
425
426 // Convert the scf.yield op to a branch back to the header block.
427 rewriter.setInsertionPointToEnd(&afterBlock);
428 rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock,
429 yieldArgs);
430
431 rewriter.replaceOp(whileOp, resultValues);
432 return success();
433 }
434};
435} // namespace
436
437//===----------------------------------------------------------------------===//
438// Public API
439//===----------------------------------------------------------------------===//
440
441void mlir::populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
442 ScfToSPIRVContext &scfToSPIRVContext,
443 RewritePatternSet &patterns) {
444 patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
445 WhileOpConversion>(arg: patterns.getContext(), args: typeConverter,
446 args: scfToSPIRVContext.getImpl());
447}
448

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp