-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinflate.luau
388 lines (326 loc) · 9.91 KB
/
inflate.luau
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
-- Huffman tree representation used by the DEFLATE decoder
local Tree = {}

type TreeInner = {
	table: { number }, -- Created with 16 slots; number of codes per code length
	trans: { number }, -- Created with 288 slots; maps code order to symbol
}
export type Tree = typeof(setmetatable({} :: TreeInner, { __index = Tree }))

--- Returns an empty tree whose two tables are pre-filled with zeros
function Tree.new(): Tree
	local inner: TreeInner = {
		table = table.create(16, 0),
		trans = table.create(288, 0),
	}
	return setmetatable(inner, { __index = Tree })
end
-- Decoder state: input/output buffers, bit reader position and per-block trees
local Data = {}

-- stylua: ignore
export type DataInner = {
	source: buffer,      -- Compressed input
	sourceIndex: number, -- Offset of the next unread byte in `source`
	tag: number,         -- Bits already fetched from `source` but not yet consumed
	bitcount: number,    -- How many bits of `tag` are valid
	dest: buffer,        -- Decompressed output
	destLen: number,     -- Number of bytes written to `dest` so far
	ltree: Tree,         -- Literal/length tree of the current dynamic block
	dtree: Tree,         -- Distance tree of the current dynamic block
}
export type Data = typeof(setmetatable({} :: DataInner, { __index = Data }))

--- Builds a fresh decoder state that reads from `source` and writes to `dest`
function Data.new(source: buffer, dest: buffer): Data
	local state: DataInner = {
		source = source,
		sourceIndex = 0,
		tag = 0,
		bitcount = 0,
		dest = dest,
		destLen = 0,
		ltree = Tree.new(),
		dtree = Tree.new(),
	}
	return setmetatable(state, { __index = Data })
end
-- Huffman trees for fixed blocks (block type 1)
local staticLengthTree = Tree.new()
local staticDistTree = Tree.new()
-- Extra-bit counts and base values, indexed by length code and by distance code
local lengthBits = table.create(30, 0)
local lengthBase = table.create(30, 0)
local distBits = table.create(30, 0)
local distBase = table.create(30, 0)
-- Order in which the code length code lengths appear in a dynamic block header
-- stylua: ignore
local clcIndex = {
	16, 17, 18,
	0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15,
}
-- Scratch tree and code length list reused while decoding dynamic block headers
local codeTree = Tree.new()
local lengths = table.create(288 + 32, 0)
--- Fills `bits` and `base` (both indexed 0..29) for a length or distance alphabet
-- The first `delta` codes carry no extra bits; after that every run of `delta`
-- codes carries one more extra bit than the previous run. `base[i]` is the
-- smallest value code `i` can represent, starting from `first`.
local function buildBitsBase(bits: { number }, base: { number }, delta: number, first: number)
	for code = 0, delta - 1 do
		bits[code] = 0
	end
	for step = 0, 29 - delta do
		bits[delta + step] = math.floor(step / delta)
	end

	-- Each code covers 2^bits consecutive values, so bases are a running total
	local running = first
	for code = 0, 29 do
		base[code] = running
		running += bit32.lshift(1, bits[code])
	end
end
--- Populates the two trees used by fixed-Huffman blocks
-- Only the entries written below are touched; other slots of the trees keep
-- whatever value they already had.
local function buildFixedTrees(lengthTree: Tree, distTree: Tree)
	local litCounts, litSymbols = lengthTree.table, lengthTree.trans

	-- No codes shorter than 7 bits; 24 seven-bit, 152 eight-bit and 112 nine-bit codes
	for len = 0, 6 do
		litCounts[len] = 0
	end
	litCounts[7], litCounts[8], litCounts[9] = 24, 152, 112

	-- Symbols are stored in code order, one contiguous symbol range after another
	local slot = 0
	local function appendRange(firstSymbol: number, lastSymbol: number)
		for symbol = firstSymbol, lastSymbol do
			litSymbols[slot] = symbol
			slot += 1
		end
	end
	appendRange(256, 279) -- 7-bit codes
	appendRange(0, 143) -- 8-bit codes, first run
	appendRange(280, 287) -- 8-bit codes, second run
	appendRange(144, 255) -- 9-bit codes

	-- Distance tree: 32 five-bit codes, symbol equals its position
	local distCounts, distSymbols = distTree.table, distTree.trans
	for len = 0, 4 do
		distCounts[len] = 0
	end
	distCounts[5] = 32
	for symbol = 0, 31 do
		distSymbols[symbol] = symbol
	end
end
--- Scratch array shared by every buildTree call; holds the next free slot per code length
local offs = table.create(16, 0)
--- Fills tree `t` from `num` code lengths stored at `lengths[off]` .. `lengths[off + num - 1]`
-- Symbols are numbered 0..num-1 relative to `off`; a length of 0 means the
-- symbol is unused and gets no entry in `t.trans`.
local function buildTree(t: Tree, lengths: { number }, off: number, num: number)
	local counts, symbols = t.table, t.trans

	-- Histogram of code lengths (lengths 0..15)
	for len = 0, 15 do
		counts[len] = 0
	end
	for symbol = 0, num - 1 do
		local len = lengths[off + symbol]
		counts[len] += 1
	end
	-- Unused symbols must not count as codes
	counts[0] = 0

	-- First slot in `symbols` for each code length
	local nextFree = 0
	for len = 0, 15 do
		offs[len] = nextFree
		nextFree += counts[len]
	end

	-- Place each used symbol in the next slot for its length
	for symbol = 0, num - 1 do
		local len = lengths[off + symbol]
		if len > 0 then
			local slot = offs[len]
			symbols[slot] = symbol
			offs[len] = slot + 1
		end
	end
end
--- Consumes and returns the next bit (0 or 1) of the input, lowest bit first
-- Loads one more source byte when the bit buffer is empty.
local function getBit(d: Data): number
	local pending = d.tag
	local available = d.bitcount
	if available <= 0 then
		pending = buffer.readu8(d.source, d.sourceIndex)
		d.sourceIndex += 1
		available = 8
	end

	d.tag = bit32.rshift(pending, 1)
	d.bitcount = available - 1
	return bit32.band(pending, 1)
end
--- Consumes `num` bits from the input and returns their value plus `base`
-- A nil `num` consumes nothing and returns `base` unchanged.
local function readBits(d: Data, num: number?, base: number): number
	if num then
		-- Top up the bit buffer a byte at a time while input remains
		local sourceLen = buffer.len(d.source)
		while d.bitcount < 24 and d.sourceIndex < sourceLen do
			local byte = buffer.readu8(d.source, d.sourceIndex)
			d.tag = bit32.bor(d.tag, bit32.lshift(byte, d.bitcount))
			d.sourceIndex += 1
			d.bitcount += 8
		end

		-- Mask with the low `num` bits set
		local mask = bit32.rshift(0xffff, 16 - num)
		local extra = bit32.band(d.tag, mask)
		d.tag = bit32.rshift(d.tag, num)
		d.bitcount -= num
		return extra + base
	end
	return base
end
--- Reads one Huffman code from the input and returns the symbol tree `t` assigns to it
local function decodeSymbol(d: Data, t: Tree): number
	-- Top up the bit buffer a byte at a time while input remains
	local sourceLen = buffer.len(d.source)
	while d.bitcount < 24 and d.sourceIndex < sourceLen do
		local byte = buffer.readu8(d.source, d.sourceIndex)
		d.tag = bit32.bor(d.tag, bit32.lshift(byte, d.bitcount))
		d.sourceIndex += 1
		d.bitcount += 8
	end

	local counts = t.table
	local pending = d.tag
	local skipped, code, len = 0, 0, 0
	-- Extend the code one bit at a time; at each length subtract the number of
	-- codes of that length, so `code` turns negative once the code is complete
	while true do
		code = 2 * code + bit32.band(pending, 1)
		pending = bit32.rshift(pending, 1)
		len += 1
		local codesOfLen = counts[len]
		skipped += codesOfLen
		code -= codesOfLen
		if code < 0 then
			break
		end
	end

	d.tag = pending
	d.bitcount -= len
	return t.trans[skipped + code]
end
--- Decodes the dynamic Huffman trees for a block
-- Reads the block header (HLIT, HDIST, HCLEN), builds the code length tree,
-- then decodes the run-length encoded code lengths and builds `lengthTree`
-- and `distTree` from them.
-- Raises "Invalid code length repeat" when the stream asks to repeat a
-- previous length before any length was decoded, or when a repeat run would
-- produce more than HLIT + HDIST code lengths.
local function decodeTrees(d: Data, lengthTree: Tree, distTree: Tree)
	local hlit = readBits(d, 5, 257) -- Number of literal/length codes
	local hdist = readBits(d, 5, 1) -- Number of distance codes
	local hclen = readBits(d, 4, 4) -- Number of code length codes
	local total = hlit + hdist

	-- Code length codes missing from the header have length 0
	for i = 0, 18 do
		lengths[i] = 0
	end
	-- Read code lengths for the code length alphabet (stored in clcIndex order)
	for i = 0, hclen - 1 do
		lengths[clcIndex[i + 1]] = readBits(d, 3, 0)
	end
	-- Build the code lengths tree
	buildTree(codeTree, lengths, 0, 19)

	-- Decode the code lengths of both alphabets as one sequence
	local num = 0
	while num < total do
		local sym = decodeSymbol(d, codeTree)
		local value, count
		if sym == 16 then
			-- Copy previous code length 3-6 times; needs a previous length
			if num == 0 then
				error("Invalid code length repeat")
			end
			value = lengths[num - 1]
			count = readBits(d, 2, 3)
		elseif sym == 17 then
			-- Repeat zero 3-10 times
			value = 0
			count = readBits(d, 3, 3)
		elseif sym == 18 then
			-- Repeat zero 11-138 times
			value = 0
			count = readBits(d, 7, 11)
		else
			-- Regular code length 0-15
			value = sym
			count = 1
		end

		-- A run may cross from the literal/length lengths into the distance
		-- lengths, but it must never run past the end of the sequence
		if num + count > total then
			error("Invalid code length repeat")
		end
		for _ = 1, count do
			lengths[num] = value
			num += 1
		end
	end

	-- Build the literal/length and distance trees
	buildTree(lengthTree, lengths, 0, hlit)
	buildTree(distTree, lengths, hlit, hdist)
end
--- Makes sure `extra` more bytes fit in the output buffer after `d.destLen`
-- When they do not fit: if `growable` is true, `d.dest` is replaced by a larger
-- buffer (at least double the old capacity) holding the bytes written so far;
-- otherwise an error is raised because the caller-supplied size was too small.
local function reserveOutput(d: Data, extra: number, growable: boolean)
	local required = d.destLen + extra
	local capacity = buffer.len(d.dest)
	if required <= capacity then
		return
	end
	if not growable then
		error("Decompressed data exceeds the expected size")
	end
	local grown = buffer.create(math.max(required, capacity * 2))
	buffer.copy(grown, 0, d.dest, 0, d.destLen)
	d.dest = grown
end

--- Inflates a block of data using Huffman trees
-- Decodes symbols until the end-of-block symbol (256) is seen. `growable`
-- tells whether the output buffer may be enlarged when it fills up.
local function inflateBlockData(d: Data, lengthTree: Tree, distTree: Tree, growable: boolean)
	while true do
		local sym = decodeSymbol(d, lengthTree)
		if sym == 256 then
			-- End of block
			return
		end
		if sym < 256 then
			-- Literal byte
			if d.destLen >= buffer.len(d.dest) then
				reserveOutput(d, 1, growable)
			end
			buffer.writeu8(d.dest, d.destLen, sym)
			d.destLen += 1
		else
			-- Length/distance pair for copying
			sym -= 257
			local length = readBits(d, lengthBits[sym], lengthBase[sym])
			local dist = decodeSymbol(d, distTree)
			local from = d.destLen - readBits(d, distBits[dist], distBase[dist])
			if from < 0 then
				error("Invalid back-reference distance")
			end
			reserveOutput(d, length, growable)
			-- Fetch the buffer after reserving: it may have been replaced.
			-- The copy stays byte-by-byte because source and target may
			-- overlap, in which case freshly written bytes are read again.
			local dest = d.dest
			local to = d.destLen
			for i = 0, length - 1 do
				buffer.writeu8(dest, to + i, buffer.readu8(dest, from + i))
			end
			d.destLen = to + length
		end
	end
end

--- Processes an uncompressed (stored) block
-- `growable` tells whether the output buffer may be enlarged when it fills up.
-- Raises "Invalid block length" when the length and its complement disagree.
local function inflateUncompressedBlock(d: Data, growable: boolean)
	-- Align to byte boundary: hand back whole bytes that were prefetched into
	-- the bit buffer and drop the remaining bits of the current byte
	d.sourceIndex -= d.bitcount // 8
	d.bitcount = 0
	d.tag = 0

	-- Block length and its ones complement, both 16-bit little-endian
	local length = buffer.readu16(d.source, d.sourceIndex)
	local invlength = buffer.readu16(d.source, d.sourceIndex + 2)
	if length ~= bit32.bxor(invlength, 0xffff) then
		error("Invalid block length")
	end
	d.sourceIndex += 4

	-- Copy uncompressed data to output
	if length > 0 then
		reserveOutput(d, length, growable)
		buffer.copy(d.dest, d.destLen, d.source, d.sourceIndex, length)
		d.destLen += length
		d.sourceIndex += length
	end
end

--- Main decompression function that processes DEFLATE compressed data
-- `source` holds the compressed stream. When `uncompressedSize` is given the
-- output buffer is created with exactly that size and exceeding it is an
-- error. When it is omitted the output starts at 7 times the compressed size
-- and grows as needed. Returns a buffer holding exactly the decompressed bytes.
local function uncompress(source: buffer, uncompressedSize: number?): buffer
	local growable = uncompressedSize == nil
	local d = Data.new(source, buffer.create(uncompressedSize or buffer.len(source) * 7))

	repeat
		local bfinal = getBit(d) -- Last block flag
		local btype = readBits(d, 2, 0) -- Block type (0=uncompressed, 1=fixed, 2=dynamic)
		if btype == 0 then
			inflateUncompressedBlock(d, growable)
		elseif btype == 1 then
			inflateBlockData(d, staticLengthTree, staticDistTree, growable)
		elseif btype == 2 then
			decodeTrees(d, d.ltree, d.dtree)
			inflateBlockData(d, d.ltree, d.dtree, growable)
		else
			error("Invalid block type")
		end
	until bfinal == 1

	-- Read the buffer back from the state: it may have been replaced by a larger one
	local dest = d.dest
	-- Trim output buffer to actual size if needed
	if d.destLen < buffer.len(dest) then
		local result = buffer.create(d.destLen)
		buffer.copy(result, 0, dest, 0, d.destLen)
		return result
	end
	return dest
end
-- One-time setup of the lookup tables and the fixed Huffman trees
buildBitsBase(distBits, distBase, 2, 1)
buildBitsBase(lengthBits, lengthBase, 4, 3)
-- Override the generated entry at index 28 (length symbol 285): base 258, no extra bits
lengthBase[28] = 258
lengthBits[28] = 0
buildFixedTrees(staticLengthTree, staticDistTree)
return uncompress