1 | // SPDX-License-Identifier: GPL-2.0-only |
2 | /* |
3 | * Cryptographic API. |
4 | * |
5 | * Copyright (c) 2017-present, Facebook, Inc. |
6 | */ |
7 | #include <linux/crypto.h> |
8 | #include <linux/init.h> |
9 | #include <linux/interrupt.h> |
10 | #include <linux/mm.h> |
11 | #include <linux/module.h> |
12 | #include <linux/net.h> |
13 | #include <linux/vmalloc.h> |
14 | #include <linux/zstd.h> |
15 | #include <crypto/internal/scompress.h> |
16 | |
17 | |
/* Compression level passed to zstd_get_params() for every context here. */
#define ZSTD_DEF_LEVEL 3

/*
 * Per-transform state. Each direction has its own vzalloc()'d workspace
 * (cwksp / dwksp) and a zstd context created over that workspace by
 * zstd_init_cctx() / zstd_init_dctx(). Only the workspaces are freed on
 * teardown; the cctx/dctx pointers are simply cleared.
 */
struct zstd_ctx {
	zstd_cctx *cctx;	/* compression context, built in cwksp */
	zstd_dctx *dctx;	/* decompression context, built in dwksp */
	void *cwksp;		/* compression workspace (vzalloc) */
	void *dwksp;		/* decompression workspace (vzalloc) */
};
26 | |
27 | static zstd_parameters zstd_params(void) |
28 | { |
29 | return zstd_get_params(ZSTD_DEF_LEVEL, estimated_src_size: 0); |
30 | } |
31 | |
32 | static int zstd_comp_init(struct zstd_ctx *ctx) |
33 | { |
34 | int ret = 0; |
35 | const zstd_parameters params = zstd_params(); |
36 | const size_t wksp_size = zstd_cctx_workspace_bound(parameters: ¶ms.cParams); |
37 | |
38 | ctx->cwksp = vzalloc(size: wksp_size); |
39 | if (!ctx->cwksp) { |
40 | ret = -ENOMEM; |
41 | goto out; |
42 | } |
43 | |
44 | ctx->cctx = zstd_init_cctx(workspace: ctx->cwksp, workspace_size: wksp_size); |
45 | if (!ctx->cctx) { |
46 | ret = -EINVAL; |
47 | goto out_free; |
48 | } |
49 | out: |
50 | return ret; |
51 | out_free: |
52 | vfree(addr: ctx->cwksp); |
53 | goto out; |
54 | } |
55 | |
56 | static int zstd_decomp_init(struct zstd_ctx *ctx) |
57 | { |
58 | int ret = 0; |
59 | const size_t wksp_size = zstd_dctx_workspace_bound(); |
60 | |
61 | ctx->dwksp = vzalloc(size: wksp_size); |
62 | if (!ctx->dwksp) { |
63 | ret = -ENOMEM; |
64 | goto out; |
65 | } |
66 | |
67 | ctx->dctx = zstd_init_dctx(workspace: ctx->dwksp, workspace_size: wksp_size); |
68 | if (!ctx->dctx) { |
69 | ret = -EINVAL; |
70 | goto out_free; |
71 | } |
72 | out: |
73 | return ret; |
74 | out_free: |
75 | vfree(addr: ctx->dwksp); |
76 | goto out; |
77 | } |
78 | |
79 | static void zstd_comp_exit(struct zstd_ctx *ctx) |
80 | { |
81 | vfree(addr: ctx->cwksp); |
82 | ctx->cwksp = NULL; |
83 | ctx->cctx = NULL; |
84 | } |
85 | |
86 | static void zstd_decomp_exit(struct zstd_ctx *ctx) |
87 | { |
88 | vfree(addr: ctx->dwksp); |
89 | ctx->dwksp = NULL; |
90 | ctx->dctx = NULL; |
91 | } |
92 | |
/*
 * Initialise the compression half, then the decompression half of a
 * struct zstd_ctx. If the second step fails, the first is undone, so the
 * caller only has to free the context memory itself.
 *
 * Return: 0 on success, otherwise the error from the failing step.
 */
static int __zstd_init(void *ctx)
{
	int err = zstd_comp_init(ctx);

	if (!err) {
		err = zstd_decomp_init(ctx);
		if (err)
			zstd_comp_exit(ctx);
	}

	return err;
}
105 | |
106 | static void *zstd_alloc_ctx(struct crypto_scomp *tfm) |
107 | { |
108 | int ret; |
109 | struct zstd_ctx *ctx; |
110 | |
111 | ctx = kzalloc(size: sizeof(*ctx), GFP_KERNEL); |
112 | if (!ctx) |
113 | return ERR_PTR(error: -ENOMEM); |
114 | |
115 | ret = __zstd_init(ctx); |
116 | if (ret) { |
117 | kfree(objp: ctx); |
118 | return ERR_PTR(error: ret); |
119 | } |
120 | |
121 | return ctx; |
122 | } |
123 | |
/* crypto_alg ->cra_init hook: the zstd_ctx lives in the tfm's own context area. */
static int zstd_init(struct crypto_tfm *tfm)
{
	return __zstd_init(crypto_tfm_ctx(tfm));
}
130 | |
/* Release both the compression and the decompression workspaces of a zstd_ctx. */
static void __zstd_exit(void *ctx)
{
	struct zstd_ctx *zctx = ctx;

	zstd_comp_exit(zctx);
	zstd_decomp_exit(zctx);
}
136 | |
/*
 * scomp ->free_ctx hook: tear down the workspaces, then free the context
 * with kfree_sensitive(), which wipes it before release.
 */
static void zstd_free_ctx(struct crypto_scomp *tfm, void *ctx)
{
	struct zstd_ctx *zctx = ctx;

	__zstd_exit(zctx);
	kfree_sensitive(zctx);
}
142 | |
/*
 * crypto_alg ->cra_exit hook: free the workspaces. The context memory
 * itself belongs to the tfm and is not freed here.
 */
static void zstd_exit(struct crypto_tfm *tfm)
{
	__zstd_exit(crypto_tfm_ctx(tfm));
}
149 | |
150 | static int __zstd_compress(const u8 *src, unsigned int slen, |
151 | u8 *dst, unsigned int *dlen, void *ctx) |
152 | { |
153 | size_t out_len; |
154 | struct zstd_ctx *zctx = ctx; |
155 | const zstd_parameters params = zstd_params(); |
156 | |
157 | out_len = zstd_compress_cctx(cctx: zctx->cctx, dst, dst_capacity: *dlen, src, src_size: slen, parameters: ¶ms); |
158 | if (zstd_is_error(code: out_len)) |
159 | return -EINVAL; |
160 | *dlen = out_len; |
161 | return 0; |
162 | } |
163 | |
164 | static int zstd_compress(struct crypto_tfm *tfm, const u8 *src, |
165 | unsigned int slen, u8 *dst, unsigned int *dlen) |
166 | { |
167 | struct zstd_ctx *ctx = crypto_tfm_ctx(tfm); |
168 | |
169 | return __zstd_compress(src, slen, dst, dlen, ctx); |
170 | } |
171 | |
172 | static int zstd_scompress(struct crypto_scomp *tfm, const u8 *src, |
173 | unsigned int slen, u8 *dst, unsigned int *dlen, |
174 | void *ctx) |
175 | { |
176 | return __zstd_compress(src, slen, dst, dlen, ctx); |
177 | } |
178 | |
179 | static int __zstd_decompress(const u8 *src, unsigned int slen, |
180 | u8 *dst, unsigned int *dlen, void *ctx) |
181 | { |
182 | size_t out_len; |
183 | struct zstd_ctx *zctx = ctx; |
184 | |
185 | out_len = zstd_decompress_dctx(dctx: zctx->dctx, dst, dst_capacity: *dlen, src, src_size: slen); |
186 | if (zstd_is_error(code: out_len)) |
187 | return -EINVAL; |
188 | *dlen = out_len; |
189 | return 0; |
190 | } |
191 | |
192 | static int zstd_decompress(struct crypto_tfm *tfm, const u8 *src, |
193 | unsigned int slen, u8 *dst, unsigned int *dlen) |
194 | { |
195 | struct zstd_ctx *ctx = crypto_tfm_ctx(tfm); |
196 | |
197 | return __zstd_decompress(src, slen, dst, dlen, ctx); |
198 | } |
199 | |
200 | static int zstd_sdecompress(struct crypto_scomp *tfm, const u8 *src, |
201 | unsigned int slen, u8 *dst, unsigned int *dlen, |
202 | void *ctx) |
203 | { |
204 | return __zstd_decompress(src, slen, dst, dlen, ctx); |
205 | } |
206 | |
207 | static struct crypto_alg alg = { |
208 | .cra_name = "zstd" , |
209 | .cra_driver_name = "zstd-generic" , |
210 | .cra_flags = CRYPTO_ALG_TYPE_COMPRESS, |
211 | .cra_ctxsize = sizeof(struct zstd_ctx), |
212 | .cra_module = THIS_MODULE, |
213 | .cra_init = zstd_init, |
214 | .cra_exit = zstd_exit, |
215 | .cra_u = { .compress = { |
216 | .coa_compress = zstd_compress, |
217 | .coa_decompress = zstd_decompress } } |
218 | }; |
219 | |
220 | static struct scomp_alg scomp = { |
221 | .alloc_ctx = zstd_alloc_ctx, |
222 | .free_ctx = zstd_free_ctx, |
223 | .compress = zstd_scompress, |
224 | .decompress = zstd_sdecompress, |
225 | .base = { |
226 | .cra_name = "zstd" , |
227 | .cra_driver_name = "zstd-scomp" , |
228 | .cra_module = THIS_MODULE, |
229 | } |
230 | }; |
231 | |
232 | static int __init zstd_mod_init(void) |
233 | { |
234 | int ret; |
235 | |
236 | ret = crypto_register_alg(alg: &alg); |
237 | if (ret) |
238 | return ret; |
239 | |
240 | ret = crypto_register_scomp(alg: &scomp); |
241 | if (ret) |
242 | crypto_unregister_alg(alg: &alg); |
243 | |
244 | return ret; |
245 | } |
246 | |
247 | static void __exit zstd_mod_fini(void) |
248 | { |
249 | crypto_unregister_alg(alg: &alg); |
250 | crypto_unregister_scomp(alg: &scomp); |
251 | } |
252 | |
/*
 * Init runs at subsys_initcall level when built in; zstd_mod_fini() runs
 * on module unload.
 */
subsys_initcall(zstd_mod_init);
module_exit(zstd_mod_fini);

MODULE_LICENSE("GPL" );
MODULE_DESCRIPTION("Zstd Compression Algorithm" );
/* Lets a request for the "zstd" algorithm name autoload this module. */
MODULE_ALIAS_CRYPTO("zstd" );
259 | |