1//===- TensorToSPIRV.cpp - Tensor to SPIR-V Patterns ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert Tensor dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h"
14#include "../SPIRVCommon/Pattern.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
19#include "mlir/Dialect/Tensor/IR/Tensor.h"
20#include "mlir/IR/AffineMap.h"
21#include "llvm/Support/Debug.h"
22
23#define DEBUG_TYPE "tensor-to-spirv-pattern"
24
25using namespace mlir;
26
27//===----------------------------------------------------------------------===//
28// Operation conversion
29//===----------------------------------------------------------------------===//
30
31namespace {
32
33/// Converts tensor.extract into loading using access chains from SPIR-V local
34/// variables.
35class TensorExtractPattern final
36 : public OpConversionPattern<tensor::ExtractOp> {
37public:
38 TensorExtractPattern(const TypeConverter &typeConverter, MLIRContext *context,
39 int64_t threshold, PatternBenefit benefit = 1)
40 : OpConversionPattern(typeConverter, context, benefit),
41 byteCountThreshold(threshold) {}
42
43 LogicalResult
44 matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor,
45 ConversionPatternRewriter &rewriter) const override {
46 auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType());
47
48 if (!isa<spirv::ScalarType>(tensorType.getElementType()))
49 return rewriter.notifyMatchFailure(extractOp, "unsupported type");
50 if (!tensorType.hasStaticShape())
51 return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
52
53 if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() >
54 byteCountThreshold * 8)
55 return rewriter.notifyMatchFailure(extractOp,
56 "exceeding byte count threshold");
57
58 Location loc = extractOp.getLoc();
59
60 int64_t rank = tensorType.getRank();
61 SmallVector<int64_t, 4> strides(rank, 1);
62 for (int i = rank - 2; i >= 0; --i) {
63 strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1);
64 }
65
66 Type varType = spirv::PointerType::get(adaptor.getTensor().getType(),
67 spirv::StorageClass::Function);
68
69 spirv::VariableOp varOp;
70 if (adaptor.getTensor().getDefiningOp<spirv::ConstantOp>()) {
71 // We could use the initializer directly; but certain driver compilers
72 // have bugs dealing with that. So for now, use spirv.Store for
73 // initialization.
74 varOp = rewriter.create<spirv::VariableOp>(loc, varType,
75 spirv::StorageClass::Function,
76 /*initializer=*/nullptr);
77 rewriter.create<spirv::StoreOp>(loc, varOp, adaptor.getTensor());
78 } else {
79 // Need to store the value to the local variable. It's questionable
80 // whether we want to support such case though.
81 return failure();
82 }
83
84 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
85 auto indexType = typeConverter.getIndexType();
86
87 Value index = spirv::linearizeIndex(indices: adaptor.getIndices(), strides,
88 /*offset=*/0, integerType: indexType, loc, builder&: rewriter);
89 auto acOp = rewriter.create<spirv::AccessChainOp>(loc, varOp, index);
90
91 rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp);
92
93 return success();
94 }
95
96private:
97 int64_t byteCountThreshold;
98};
99
100} // namespace
101
102//===----------------------------------------------------------------------===//
103// Pattern population
104//===----------------------------------------------------------------------===//
105
106void mlir::populateTensorToSPIRVPatterns(
107 const SPIRVTypeConverter &typeConverter, int64_t byteCountThreshold,
108 RewritePatternSet &patterns) {
109 patterns.add<TensorExtractPattern>(arg: typeConverter, args: patterns.getContext(),
110 args&: byteCountThreshold);
111}
112

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp