1//===- TensorToSPIRV.cpp - Tensor to SPIR-V Patterns ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert Tensor dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h"
14#include "../SPIRVCommon/Pattern.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
19#include "mlir/Dialect/Tensor/IR/Tensor.h"
20#include "mlir/IR/AffineMap.h"
21#include "mlir/Support/LogicalResult.h"
22#include "llvm/Support/Debug.h"
23
24#define DEBUG_TYPE "tensor-to-spirv-pattern"
25
26using namespace mlir;
27
28//===----------------------------------------------------------------------===//
29// Operation conversion
30//===----------------------------------------------------------------------===//
31
32namespace {
33
34/// Converts tensor.extract into loading using access chains from SPIR-V local
35/// variables.
36class TensorExtractPattern final
37 : public OpConversionPattern<tensor::ExtractOp> {
38public:
39 TensorExtractPattern(TypeConverter &typeConverter, MLIRContext *context,
40 int64_t threshold, PatternBenefit benefit = 1)
41 : OpConversionPattern(typeConverter, context, benefit),
42 byteCountThreshold(threshold) {}
43
44 LogicalResult
45 matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor,
46 ConversionPatternRewriter &rewriter) const override {
47 auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType());
48
49 if (!tensorType.hasStaticShape())
50 return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
51
52 if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() >
53 byteCountThreshold * 8)
54 return rewriter.notifyMatchFailure(extractOp,
55 "exceeding byte count threshold");
56
57 Location loc = extractOp.getLoc();
58
59 int64_t rank = tensorType.getRank();
60 SmallVector<int64_t, 4> strides(rank, 1);
61 for (int i = rank - 2; i >= 0; --i) {
62 strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1);
63 }
64
65 Type varType = spirv::PointerType::get(adaptor.getTensor().getType(),
66 spirv::StorageClass::Function);
67
68 spirv::VariableOp varOp;
69 if (adaptor.getTensor().getDefiningOp<spirv::ConstantOp>()) {
70 // We could use the initializer directly; but certain driver compilers
71 // have bugs dealing with that. So for now, use spirv.Store for
72 // initialization.
73 varOp = rewriter.create<spirv::VariableOp>(loc, varType,
74 spirv::StorageClass::Function,
75 /*initializer=*/nullptr);
76 rewriter.create<spirv::StoreOp>(loc, varOp, adaptor.getTensor());
77 } else {
78 // Need to store the value to the local variable. It's questionable
79 // whether we want to support such case though.
80 return failure();
81 }
82
83 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
84 auto indexType = typeConverter.getIndexType();
85
86 Value index = spirv::linearizeIndex(indices: adaptor.getIndices(), strides,
87 /*offset=*/0, integerType: indexType, loc, builder&: rewriter);
88 auto acOp = rewriter.create<spirv::AccessChainOp>(loc, varOp, index);
89
90 rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp);
91
92 return success();
93 }
94
95private:
96 int64_t byteCountThreshold;
97};
98
99} // namespace
100
101//===----------------------------------------------------------------------===//
102// Pattern population
103//===----------------------------------------------------------------------===//
104
105void mlir::populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
106 int64_t byteCountThreshold,
107 RewritePatternSet &patterns) {
108 patterns.add<TensorExtractPattern>(arg&: typeConverter, args: patterns.getContext(),
109 args&: byteCountThreshold);
110}
111

source code of mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp