1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * Copyright (C) 2024, SUSE LLC
4 *
5 * Authors: Enzo Matsumiya <ematsumiya@suse.de>
6 *
7 * Implementation of the LZ77 "plain" compression algorithm, as per MS-XCA spec.
8 */
#include <linux/bitops.h>
#include <linux/slab.h>
#include <linux/sizes.h>
#include <linux/count_zeros.h>
#include <linux/unaligned.h>

#include "lz77.h"
15
/*
 * Encoder tunables.
 *
 * Hash-based match finder: one table of LZ77_HASH_SIZE slots, indexed by the top
 * LZ77_HASH_LOG bits of a multiplicative hash.  Candidate matches are compared
 * LZ77_STEP_SIZE bytes at a time.
 */
#define LZ77_HASH_LOG		15
#define LZ77_HASH_SIZE		(1 << LZ77_HASH_LOG)
#define LZ77_STEP_SIZE		sizeof(u64)

/*
 * Match acceptance: shorter matches than LZ77_MATCH_MIN_LEN are emitted as literals, and
 * only candidates strictly closer than LZ77_MATCH_MAX_DIST bytes are tried.
 * LZ77_MATCH_MIN_DIST is not referenced in this file.
 */
#define LZ77_MATCH_MIN_LEN	4
#define LZ77_MATCH_MIN_DIST	1
#define LZ77_MATCH_MAX_DIST	SZ_1K
25
26static __always_inline u8 lz77_read8(const u8 *ptr)
27{
28 return get_unaligned(ptr);
29}
30
31static __always_inline u64 lz77_read64(const u64 *ptr)
32{
33 return get_unaligned(ptr);
34}
35
36static __always_inline void lz77_write8(u8 *ptr, u8 v)
37{
38 put_unaligned(v, ptr);
39}
40
41static __always_inline void lz77_write16(u16 *ptr, u16 v)
42{
43 put_unaligned_le16(val: v, p: ptr);
44}
45
46static __always_inline void lz77_write32(u32 *ptr, u32 v)
47{
48 put_unaligned_le32(val: v, p: ptr);
49}
50
51static __always_inline u32 lz77_match_len(const void *wnd, const void *cur, const void *end)
52{
53 const void *start = cur;
54 u64 diff;
55
56 /* Safe for a do/while because otherwise we wouldn't reach here from the main loop. */
57 do {
58 diff = lz77_read64(ptr: cur) ^ lz77_read64(ptr: wnd);
59 if (!diff) {
60 cur += LZ77_STEP_SIZE;
61 wnd += LZ77_STEP_SIZE;
62
63 continue;
64 }
65
66 /* This computes the number of common bytes in @diff. */
67 cur += count_trailing_zeros(x: diff) >> 3;
68
69 return (cur - start);
70 } while (likely(cur + LZ77_STEP_SIZE < end));
71
72 while (cur < end && lz77_read8(ptr: cur++) == lz77_read8(ptr: wnd++))
73 ;
74
75 return (cur - start);
76}
77
78static __always_inline void *lz77_write_match(void *dst, void **nib, u32 dist, u32 len)
79{
80 len -= 3;
81 dist--;
82 dist <<= 3;
83
84 if (len < 7) {
85 lz77_write16(ptr: dst, v: dist + len);
86
87 return dst + 2;
88 }
89
90 dist |= 7;
91 lz77_write16(ptr: dst, v: dist);
92 dst += 2;
93 len -= 7;
94
95 if (!*nib) {
96 lz77_write8(ptr: dst, umin(len, 15));
97 *nib = dst;
98 dst++;
99 } else {
100 u8 *b = *nib;
101
102 lz77_write8(ptr: b, v: *b | umin(len, 15) << 4);
103 *nib = NULL;
104 }
105
106 if (len < 15)
107 return dst;
108
109 len -= 15;
110 if (len < 255) {
111 lz77_write8(ptr: dst, v: len);
112
113 return dst + 1;
114 }
115
116 lz77_write8(ptr: dst, v: 0xff);
117 dst++;
118 len += 7 + 15;
119 if (len <= 0xffff) {
120 lz77_write16(ptr: dst, v: len);
121
122 return dst + 2;
123 }
124
125 lz77_write16(ptr: dst, v: 0);
126 dst += 2;
127 lz77_write32(ptr: dst, v: len);
128
129 return dst + 4;
130}
131
132noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen)
133{
134 const void *srcp, *end;
135 void *dstp, *nib, *flag_pos;
136 u32 flag_count = 0;
137 long flag = 0;
138 u64 *htable;
139
140 srcp = src;
141 end = src + slen;
142 dstp = dst;
143 nib = NULL;
144 flag_pos = dstp;
145 dstp += 4;
146
147 htable = kvcalloc(LZ77_HASH_SIZE, sizeof(*htable), GFP_KERNEL);
148 if (!htable)
149 return -ENOMEM;
150
151 /* Main loop. */
152 do {
153 u32 dist, len = 0;
154 const void *wnd;
155 u64 hash;
156
157 hash = ((lz77_read64(ptr: srcp) << 24) * 889523592379ULL) >> (64 - LZ77_HASH_LOG);
158 wnd = src + htable[hash];
159 htable[hash] = srcp - src;
160 dist = srcp - wnd;
161
162 if (dist && dist < LZ77_MATCH_MAX_DIST)
163 len = lz77_match_len(wnd, cur: srcp, end);
164
165 if (len < LZ77_MATCH_MIN_LEN) {
166 lz77_write8(ptr: dstp, v: lz77_read8(ptr: srcp));
167
168 dstp++;
169 srcp++;
170
171 flag <<= 1;
172 flag_count++;
173 if (flag_count == 32) {
174 lz77_write32(ptr: flag_pos, v: flag);
175 flag_count = 0;
176 flag_pos = dstp;
177 dstp += 4;
178 }
179
180 continue;
181 }
182
183 /*
184 * Bail out if @dstp reached >= 7/8 of @slen -- already compressed badly, not worth
185 * going further.
186 */
187 if (unlikely(dstp - dst >= slen - (slen >> 3))) {
188 *dlen = slen;
189 goto out;
190 }
191
192 dstp = lz77_write_match(dst: dstp, nib: &nib, dist, len);
193 srcp += len;
194
195 flag = (flag << 1) | 1;
196 flag_count++;
197 if (flag_count == 32) {
198 lz77_write32(ptr: flag_pos, v: flag);
199 flag_count = 0;
200 flag_pos = dstp;
201 dstp += 4;
202 }
203 } while (likely(srcp + LZ77_STEP_SIZE < end));
204
205 while (srcp < end) {
206 u32 c = umin(end - srcp, 32 - flag_count);
207
208 memcpy(dstp, srcp, c);
209
210 dstp += c;
211 srcp += c;
212
213 flag <<= c;
214 flag_count += c;
215 if (flag_count == 32) {
216 lz77_write32(ptr: flag_pos, v: flag);
217 flag_count = 0;
218 flag_pos = dstp;
219 dstp += 4;
220 }
221 }
222
223 flag <<= (32 - flag_count);
224 flag |= (1 << (32 - flag_count)) - 1;
225 lz77_write32(ptr: flag_pos, v: flag);
226
227 *dlen = dstp - dst;
228out:
229 kvfree(addr: htable);
230
231 if (*dlen < slen)
232 return 0;
233
234 return -EMSGSIZE;
235}
236

source code of linux/fs/smb/client/compress/lz77.c