//===- TensorToSPIRV.cpp - Tensor to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Tensor dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h"
14#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
16#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
17#include "mlir/Dialect/Tensor/IR/Tensor.h"
18#include "mlir/IR/AffineMap.h"
19
20#define DEBUG_TYPE "tensor-to-spirv-pattern"
21
22using namespace mlir;
23
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
27
28namespace {
29
30/// Converts tensor.extract into loading using access chains from SPIR-V local
31/// variables.
32class TensorExtractPattern final
33 : public OpConversionPattern<tensor::ExtractOp> {
34public:
35 TensorExtractPattern(const TypeConverter &typeConverter, MLIRContext *context,
36 int64_t threshold, PatternBenefit benefit = 1)
37 : OpConversionPattern(typeConverter, context, benefit),
38 byteCountThreshold(threshold) {}
39
40 LogicalResult
41 matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor,
42 ConversionPatternRewriter &rewriter) const override {
43 auto tensorType = cast<RankedTensorType>(Val: extractOp.getTensor().getType());
44
45 if (!isa<spirv::ScalarType>(Val: tensorType.getElementType()))
46 return rewriter.notifyMatchFailure(arg&: extractOp, msg: "unsupported type");
47 if (!tensorType.hasStaticShape())
48 return rewriter.notifyMatchFailure(arg&: extractOp, msg: "non-static tensor");
49
50 if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() >
51 byteCountThreshold * 8)
52 return rewriter.notifyMatchFailure(arg&: extractOp,
53 msg: "exceeding byte count threshold");
54
55 Location loc = extractOp.getLoc();
56
57 int64_t rank = tensorType.getRank();
58 SmallVector<int64_t, 4> strides(rank, 1);
59 for (int i = rank - 2; i >= 0; --i) {
60 strides[i] = strides[i + 1] * tensorType.getDimSize(idx: i + 1);
61 }
62
63 Type varType = spirv::PointerType::get(pointeeType: adaptor.getTensor().getType(),
64 storageClass: spirv::StorageClass::Function);
65
66 spirv::VariableOp varOp;
67 if (adaptor.getTensor().getDefiningOp<spirv::ConstantOp>()) {
68 // We could use the initializer directly; but certain driver compilers
69 // have bugs dealing with that. So for now, use spirv.Store for
70 // initialization.
71 varOp = rewriter.create<spirv::VariableOp>(location: loc, args&: varType,
72 args: spirv::StorageClass::Function,
73 /*initializer=*/args: nullptr);
74 rewriter.create<spirv::StoreOp>(location: loc, args&: varOp, args: adaptor.getTensor());
75 } else {
76 // Need to store the value to the local variable. It's questionable
77 // whether we want to support such case though.
78 return failure();
79 }
80
81 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
82 auto indexType = typeConverter.getIndexType();
83
84 Value index = spirv::linearizeIndex(indices: adaptor.getIndices(), strides,
85 /*offset=*/0, integerType: indexType, loc, builder&: rewriter);
86 auto acOp = rewriter.create<spirv::AccessChainOp>(location: loc, args&: varOp, args&: index);
87
88 rewriter.replaceOpWithNewOp<spirv::LoadOp>(op: extractOp, args&: acOp);
89
90 return success();
91 }
92
93private:
94 int64_t byteCountThreshold;
95};
96
97} // namespace
98
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
102
103void mlir::populateTensorToSPIRVPatterns(
104 const SPIRVTypeConverter &typeConverter, int64_t byteCountThreshold,
105 RewritePatternSet &patterns) {
106 patterns.add<TensorExtractPattern>(arg: typeConverter, args: patterns.getContext(),
107 args&: byteCountThreshold);
108}
109

// Source: mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp