huffman.zig (32753B)
//! This module provides functions for Huffman encoding and decoding, specific
//! to the PIZ compression format.
//!
//! A compressed block, as written by `compress`, consists of:
//!   1. five little-endian `u32` values: `min_code_idx`, `max_code_idx`,
//!      `table_size` (bytes), `num_bits` (payload bits) and a zero padding word,
//!   2. the packed table of code lengths (`table_size` bytes),
//!   3. the Huffman-coded payload, padded with zero bits to a whole byte.

const std = @import("std");
const assert = std.debug.assert;
const expect = std.testing.expect;

const encoding_bits = 16;
const decoding_bits = 14;

/// One slot per 16-bit symbol plus one for the run-length pseudo-symbol.
const encoding_table_size = (1 << encoding_bits) + 1;
const decoding_table_size = 1 << decoding_bits;
const decoding_mask = decoding_table_size - 1;

/// Longest code, in bits, that fits in `Encoding.val`.
const max_code_len = 58;

/// A Huffman code: the low `len` bits of `val`. `len == 0` means "unused".
const Encoding = packed struct {
    val: u58,
    len: u6,
};

const DecodingType = enum { empty, short, long };

/// One slot of the decoding table, indexed by the first `decoding_bits` bits
/// of the input.
const Decoding = union(DecodingType) {
    /// No code starts with this prefix.
    empty: void,
    /// A code of at most `decoding_bits` bits starts with this prefix.
    short: ShortDecoding,
    /// Symbols whose codes are longer than `decoding_bits` bits and share
    /// this prefix. The list is owned by the table; release the table with
    /// `freeDecodingTable`.
    long: std.ArrayList(u32),
};

/// `val` is the decoded symbol. It is 32 bits wide because the run-length
/// pseudo-symbol can be 65536, which does not fit in a `u16`.
const ShortDecoding = packed struct {
    val: u32,
    len: u6,
    _: u26 = 0,
};

/// Counts how often each value occurs in `data`.
/// Returns a slice of `encoding_table_size` counters; the caller owns it.
fn countFrequencies(data: []const u16, allocator: std.mem.Allocator) ![]u64 {
    const frequencies = try allocator.alloc(u64, encoding_table_size);
    @memset(frequencies, 0);
    for (data) |value| {
        frequencies[value] += 1;
    }
    return frequencies;
}

const Frequency = struct {
    frequency: u64,
    index: u32,

    /// Orders by frequency, then by index, so the heap order is total and
    /// the generated code lengths are deterministic.
    fn compare(_: void, a: Frequency, b: Frequency) std.math.Order {
        const freq_order = std.math.order(a.frequency, b.frequency);
        return switch (freq_order) {
            .eq => std.math.order(a.index, b.index),
            else => freq_order,
        };
    }
};
const FrequencyHeap = std.PriorityQueue(Frequency, void, Frequency.compare);

const EncodingTableResult = struct {
    /// `encoding_table_size` entries; owned by the caller.
    encoding_table: []Encoding,
    /// Lowest symbol with a non-zero frequency (0 if there is none).
    min_code_idx: u32,
    /// Index of the run-length pseudo-symbol, one past the highest symbol
    /// with a non-zero frequency.
    max_code_idx: u32,
};

/// Adds one bit to the code length of every symbol on the chain that starts
/// at `head`, and returns the last symbol of that chain.
fn lengthenChain(encoding_table: []Encoding, link: []const u32, head: u32) u32 {
    var index = head;
    while (true) : (index = link[index]) {
        encoding_table[index].len += 1;
        assert(encoding_table[index].len <= max_code_len);
        if (index == link[index]) return index;
    }
}

/// Computes Huffman code lengths for `frequencies` (code values are filled in
/// later by `buildCanonicalEncodings`). A pseudo-symbol with frequency 1 is
/// appended after the highest used symbol; it marks run-length encoding.
///
/// `frequencies` must have `encoding_table_size` entries and its last entry
/// must be zero so that the pseudo-symbol always fits in the table.
fn buildEncodingTable(
    frequencies: []const u64,
    allocator: std.mem.Allocator,
) !EncodingTableResult {
    assert(frequencies.len == encoding_table_size);
    assert(frequencies[encoding_table_size - 1] == 0);
    // Initialize a heap of frequencies. `link` chains together all symbols
    // that live in the same subtree; initially every symbol is its own chain.
    var heap = FrequencyHeap.init(allocator, {});
    defer heap.deinit();
    const link = try allocator.alloc(u32, encoding_table_size);
    defer allocator.free(link);
    var min_non_zero: ?u32 = null;
    var max_non_zero: u32 = 0;
    for (frequencies, 0..) |frequency, i| {
        const symbol: u32 = @intCast(i);
        link[i] = symbol;
        if (frequency == 0) continue;
        try heap.add(.{ .frequency = frequency, .index = symbol });
        if (min_non_zero == null) min_non_zero = symbol;
        max_non_zero = symbol;
    }
    // Add a pseudo-symbol which will indicate run-length encoding.
    const rle_symbol = max_non_zero + 1;
    try heap.add(.{ .frequency = 1, .index = rle_symbol });
    // Compute code lengths for each symbol.
    const encoding_table = try allocator.alloc(Encoding, encoding_table_size);
    errdefer allocator.free(encoding_table);
    @memset(encoding_table, .{ .val = 0, .len = 0 });
    while (heap.count() > 1) {
        const lowest = heap.remove();
        const second_lowest = heap.remove();
        try heap.add(.{
            .frequency = lowest.frequency + second_lowest.frequency,
            .index = second_lowest.index,
        });
        // Merging two subtrees pushes every symbol in both one level down.
        const tail = lengthenChain(encoding_table, link, second_lowest.index);
        link[tail] = lowest.index;
        _ = lengthenChain(encoding_table, link, lowest.index);
    }
    return .{
        .encoding_table = encoding_table,
        .min_code_idx = min_non_zero orelse 0,
        .max_code_idx = rle_symbol,
    };
}

/// Assigns canonical code values to a table that already holds code lengths.
/// The lengths must describe a complete prefix code (this is asserted); use
/// `validateCodeLengths` first when the table comes from untrusted input.
fn buildCanonicalEncodings(encoding_table: []Encoding) void {
    // This algorithm originates here: http://www.compressconsult.com/huffman/#huffman
    // How many codes of each length are there?
    var num_codes_per_length = [_]u32{0} ** (max_code_len + 1);
    for (encoding_table) |encoding| {
        num_codes_per_length[encoding.len] += 1;
    }
    // What is the lowest code for each length? Walk from the longest length
    // down to zero.
    var code: u58 = 0;
    var codes_per_length: [max_code_len + 1]u58 = undefined;
    var len: usize = codes_per_length.len;
    while (len > 0) {
        len -= 1;
        codes_per_length[len] = code;
        code = (code + num_codes_per_length[len]) >> 1;
    }
    assert(codes_per_length[0] == 1);
    // Assign lowest code, then increment.
    for (encoding_table) |*encoding| {
        if (encoding.len == 0) continue;
        encoding.val = codes_per_length[encoding.len];
        codes_per_length[encoding.len] += 1;
        // Widen first: shifting a u58 by 58 is not a legal shift.
        assert(@as(u64, encoding.val) >> encoding.len == 0);
    }
}

/// Checks that the code lengths in `encoding_table` form a complete prefix
/// code (Kraft sum exactly one), which is what `buildCanonicalEncodings`
/// asserts. Returns `error.InvalidEncodingTable` otherwise.
fn validateCodeLengths(encoding_table: []const Encoding) !void {
    const full: u64 = @as(u64, 1) << max_code_len;
    var used: u64 = 0;
    for (encoding_table) |encoding| {
        if (encoding.len == 0) continue;
        if (encoding.len > max_code_len) return error.InvalidEncodingTable;
        // `used <= full` here and the term is at most `full / 2`, so the sum
        // cannot overflow a u64.
        used += @as(u64, 1) << (max_code_len - encoding.len);
        if (used > full) return error.InvalidEncodingTable;
    }
    if (used != full) return error.InvalidEncodingTable;
}

/// Huffman-encodes `data` to `writer`, using the code of `rle_symbol` to
/// shorten runs of identical values where that saves space. Returns the
/// number of payload bits written (before padding to a whole byte).
fn encode(
    data: []const u16,
    encoding_table: []const Encoding,
    rle_symbol: u32,
    writer: anytype,
) !u32 {
    var num_bits: u32 = 0;
    if (data.len == 0) return num_bits;
    var bit_writer = std.io.bitWriter(std.builtin.Endian.big, writer);
    // Encode and write all symbols.
    const rle_encoding = encoding_table[rle_symbol];
    // `run_length` counts the repetitions after the first occurrence.
    var run_length: u8 = 0;
    var run_symbol = data[0];
    for (data[1..]) |x| {
        if (x == run_symbol and run_length < 255) {
            run_length += 1;
        } else {
            num_bits += try writeCode(
                encoding_table[run_symbol],
                rle_encoding,
                run_length,
                &bit_writer,
            );
            run_length = 0;
        }
        run_symbol = x;
    }
    num_bits += try writeCode(
        encoding_table[run_symbol],
        rle_encoding,
        run_length,
        &bit_writer,
    );
    try bit_writer.flushBits();
    return num_bits;
}

/// Outputs `run_length + 1` instances of `encoding`. This is stored either as
/// `encoding, rle_encoding, run_length` or as `[encoding] * (run_length + 1)`,
/// whichever the size comparison below selects. Returns the bits written.
fn writeCode(
    encoding: Encoding,
    rle_encoding: Encoding,
    run_length: u8,
    bit_writer: anytype,
) !u32 {
    // Do the size arithmetic in u32: the operands are u6/u8 and their sums
    // and products do not fit in those types.
    const code_len: u32 = encoding.len;
    const rle_length = code_len + rle_encoding.len + 8;
    const rep_length = code_len * run_length;
    if (rle_length < rep_length) {
        try bit_writer.writeBits(encoding.val, encoding.len);
        try bit_writer.writeBits(rle_encoding.val, rle_encoding.len);
        try bit_writer.writeBits(run_length, 8);
        return rle_length;
    }
    const copies = @as(u32, run_length) + 1;
    for (0..copies) |_| {
        try bit_writer.writeBits(encoding.val, encoding.len);
    }
    return code_len * copies;
}

/// Releases a table returned by `buildDecodingTable`, including the
/// long-code lists stored in its slots.
fn freeDecodingTable(decoding_table: []Decoding, allocator: std.mem.Allocator) void {
    for (decoding_table) |decoding| {
        switch (decoding) {
            .long => |symbols| symbols.deinit(),
            .empty, .short => {},
        }
    }
    allocator.free(decoding_table);
}

/// Builds the lookup table used by `decode` from the codes of the symbols
/// `min_code_idx..max_code_idx` (inclusive). Returns
/// `error.InvalidEncodingTableEntry` if a code value does not fit its length
/// or if two codes claim the same slot. Release the result with
/// `freeDecodingTable`.
fn buildDecodingTable(
    encoding_table: []const Encoding,
    min_code_idx: u32,
    max_code_idx: u32,
    allocator: std.mem.Allocator,
) ![]Decoding {
    const decoding_table = try allocator.alloc(Decoding, decoding_table_size);
    @memset(decoding_table, Decoding{ .empty = {} });
    errdefer freeDecodingTable(decoding_table, allocator);
    for (min_code_idx..max_code_idx + 1) |symbol| {
        const encoding = encoding_table[symbol];
        if (encoding.len == 0) continue;
        if (@as(u64, encoding.val) >> encoding.len != 0) {
            return error.InvalidEncodingTableEntry;
        }
        if (encoding.len > decoding_bits) {
            // If symbol `x` is encoded as `XX111111111111` and symbol `y` is
            // encoded as `YY111111111111`, then index `11111111111111` in the
            // decoding table will contain the list `[x, y]`.
            const shift = encoding.len - decoding_bits;
            const decoding_idx: usize = @intCast(encoding.val >> shift);
            const decoding = &decoding_table[decoding_idx];
            switch (decoding.*) {
                .empty => {
                    decoding.* = Decoding{
                        .long = std.ArrayList(u32).init(allocator),
                    };
                    try decoding.long.append(@intCast(symbol));
                },
                .long => try decoding.long.append(@intCast(symbol)),
                .short => return error.InvalidEncodingTableEntry,
            }
        } else {
            // If symbol `x` is encoded as `XXX`, then all entries between
            // `XXX00000000000` and `XXX11111111111` map to the value `x` in
            // the decoding table.
            const shift = decoding_bits - encoding.len;
            const start: usize = @intCast(@as(u64, encoding.val) << shift);
            const count = @as(usize, 1) << @intCast(shift);
            for (decoding_table[start .. start + count]) |*decoding| {
                // A prefix-free code never claims a slot twice.
                if (decoding.* != .empty) return error.InvalidEncodingTableEntry;
                decoding.* = Decoding{ .short = .{
                    .val = @intCast(symbol),
                    .len = encoding.len,
                } };
            }
        }
    }
    return decoding_table;
}

const BitBufferError = error{
    BitOverflow,
    BitUnderflow,
};

/// A small MSB-first bit queue: new bits enter at the low end of `bits` and
/// the oldest bits are read from the high end. Holds at most 63 bits.
const BitBuffer = struct {
    bits: u64 = 0,
    num_bits: u6 = 0,

    /// Appends the low `num_bits` bits of `bits`.
    fn addBits(self: *BitBuffer, bits: u64, num_bits: u6) BitBufferError!void {
        if (@as(u8, 63) - num_bits < self.num_bits) {
            return error.BitOverflow;
        }
        const mask = (@as(u64, 1) << num_bits) - 1;
        self.bits = (self.bits << num_bits) | (bits & mask);
        self.num_bits += num_bits;
    }

    /// Returns the oldest `num_bits` bits without consuming them.
    fn peekBits(self: *const BitBuffer, comptime T: type, num_bits: u6) BitBufferError!T {
        if (num_bits > self.num_bits) {
            return error.BitUnderflow;
        }
        const mask = (@as(u64, 1) << num_bits) - 1;
        const val = (self.bits >> (self.num_bits - num_bits)) & mask;
        return @as(T, @intCast(val));
    }

    /// Returns and consumes the oldest `num_bits` bits.
    fn readBits(self: *BitBuffer, comptime T: type, num_bits: u6) BitBufferError!T {
        const val = try self.peekBits(T, num_bits);
        self.num_bits -= num_bits;
        return val;
    }

    /// Discards the `num_bits` most recently added bits.
    fn dumpBits(self: *BitBuffer, num_bits: u6) BitBufferError!void {
        if (num_bits > self.num_bits) {
            return error.BitUnderflow;
        }
        self.bits >>= num_bits;
        self.num_bits -= num_bits;
    }
};

/// Reads bytes from `reader` until `bit_buffer` holds at least `wanted_bits`
/// bits or the payload (`unread_bits.*` bits still in the stream) is used up.
/// Only payload bits enter the buffer: the padding bits of the final byte are
/// dropped here. Returns `error.EndOfStream` if the stream ends before the
/// declared payload does.
fn refillBits(
    bit_buffer: *BitBuffer,
    reader: anytype,
    unread_bits: *u64,
    wanted_bits: u6,
) !void {
    while (bit_buffer.num_bits < wanted_bits and unread_bits.* > 0) {
        const byte: u8 = try reader.readByte();
        if (unread_bits.* >= 8) {
            try bit_buffer.addBits(byte, 8);
            unread_bits.* -= 8;
        } else {
            // Last byte: only its top `valid` bits belong to the payload.
            const valid: u6 = @intCast(unread_bits.*);
            const shift: u3 = @intCast(8 - @as(u8, valid));
            try bit_buffer.addBits(byte >> shift, valid);
            unread_bits.* = 0;
        }
    }
}

/// Consumes one code from `bit_buffer` and returns its symbol. The caller
/// must already have refilled the buffer to `decoding_bits` bits if the
/// payload allows it, and the buffer must not be empty.
fn decodeSymbol(
    encoding_table: []const Encoding,
    decoding_table: []const Decoding,
    bit_buffer: *BitBuffer,
    reader: anytype,
    unread_bits: *u64,
) !u32 {
    // Left-align whatever is available into a `decoding_bits`-wide index.
    // Near the end of the payload the missing low bits read as zero.
    const available = bit_buffer.num_bits;
    const prefix: u64 = if (available >= decoding_bits)
        try bit_buffer.peekBits(u64, decoding_bits)
    else
        (try bit_buffer.peekBits(u64, available)) << (decoding_bits - available);
    const decoding_idx: usize = @intCast(prefix);
    switch (decoding_table[decoding_idx]) {
        .empty => return error.InvalidCode,
        .short => |short| {
            // The code would extend past the end of the payload.
            if (short.len > available) return error.InvalidCode;
            bit_buffer.num_bits -= short.len;
            return short.val;
        },
        .long => |candidates| {
            // For long codes, we have to search the list stored in the
            // decoding table for a matching entry.
            for (candidates.items) |symbol| {
                const encoding = encoding_table[symbol];
                // NOTE: the buffer holds 63 bits, so codes of 57 or 58 bits
                // can fail here with `error.BitOverflow`.
                try refillBits(bit_buffer, reader, unread_bits, encoding.len);
                if (encoding.len > bit_buffer.num_bits) continue;
                const code = try bit_buffer.peekBits(u64, encoding.len);
                if (code == encoding.val) {
                    bit_buffer.num_bits -= encoding.len;
                    return symbol;
                }
            }
            return error.InvalidCode;
        },
    }
}

/// Decodes `num_bits` bits of Huffman-coded payload from `reader` and writes
/// the symbols to `writer` as little-endian `u16` values.
///
/// `rle_symbol` is the run-length pseudo-symbol (`max_code_idx`): when it is
/// decoded, the next 8 bits give the number of extra copies of the previously
/// written symbol. Returns `error.InvalidCode` on malformed payload and
/// `error.EndOfStream` if `reader` ends before `num_bits` bits were read.
fn decode(
    encoding_table: []const Encoding,
    decoding_table: []const Decoding,
    reader: anytype,
    writer: anytype,
    num_bits: u32,
    rle_symbol: u32,
) !void {
    var bit_buffer = BitBuffer{};
    var unread_bits: u64 = num_bits;
    var previous: ?u16 = null;
    while (true) {
        try refillBits(&bit_buffer, reader, &unread_bits, decoding_bits);
        if (bit_buffer.num_bits == 0) break;
        const symbol = try decodeSymbol(
            encoding_table,
            decoding_table,
            &bit_buffer,
            reader,
            &unread_bits,
        );
        if (symbol == rle_symbol) {
            try refillBits(&bit_buffer, reader, &unread_bits, 8);
            const run_length = bit_buffer.readBits(u8, 8) catch return error.InvalidCode;
            // A run marker must follow the symbol it repeats.
            const repeated = previous orelse return error.InvalidCode;
            for (0..run_length) |_| {
                try writer.writeInt(u16, repeated, .little);
            }
        } else {
            const value = std.math.cast(u16, symbol) orelse return error.InvalidCode;
            try writer.writeInt(u16, value, .little);
            previous = value;
        }
    }
}

// Code lengths are packed as 6-bit values. The values 59..62 stand for a run
// of 2..5 unused symbols; 63 is followed by an 8-bit count and stands for a
// run of 6..261 unused symbols.
const short_zerocode_run = 59;
const long_zerocode_run = 63;
const shortest_long_run = 2 + long_zerocode_run - short_zerocode_run;
const longest_long_run = 255 + shortest_long_run;

/// Writes a run of `run_len` unused symbols (`run_len <= longest_long_run`)
/// and returns the number of bits written.
fn writeZeroRun(bit_writer: anytype, run_len: u32) !u32 {
    if (run_len >= shortest_long_run) {
        try bit_writer.writeBits(@as(u6, long_zerocode_run), 6);
        try bit_writer.writeBits(@as(u8, @intCast(run_len - shortest_long_run)), 8);
        return 14;
    }
    if (run_len >= 2) {
        try bit_writer.writeBits(@as(u6, @intCast(run_len - 2 + short_zerocode_run)), 6);
        return 6;
    }
    if (run_len == 1) {
        try bit_writer.writeBits(@as(u6, 0), 6);
        return 6;
    }
    return 0;
}

/// Writes the code lengths of symbols `min_code_idx..max_code_idx`
/// (inclusive) to `writer`, run-length encoding unused symbols. Returns the
/// number of bytes written.
fn packEncodingTable(
    encoding_table: []Encoding,
    min_code_idx: u32,
    max_code_idx: u32,
    writer: anytype,
) !u32 {
    var num_bits: u32 = 0;
    var bit_writer = std.io.bitWriter(std.builtin.Endian.big, writer);
    // Wider than 8 bits: a run can reach `longest_long_run` (261).
    var run_len: u32 = 0;
    for (encoding_table[min_code_idx .. max_code_idx + 1]) |encoding| {
        if (encoding.len == 0 and run_len < longest_long_run) {
            run_len += 1;
            continue;
        }
        num_bits += try writeZeroRun(&bit_writer, run_len);
        run_len = 0;
        try bit_writer.writeBits(encoding.len, 6);
        num_bits += 6;
    }
    // Flush a run that reaches the end of the range.
    num_bits += try writeZeroRun(&bit_writer, run_len);
    try bit_writer.flushBits();
    return (num_bits + 7) / 8;
}

/// Reads a table written by `packEncodingTable`. Only code lengths are
/// filled in; call `buildCanonicalEncodings` for the code values. The caller
/// owns the returned slice of `encoding_table_size` entries.
///
/// Returns `error.InvalidEncodingTable` if the index range is not inside the
/// table and `error.EncodingTableTooLong` if a run passes `max_code_idx`.
fn readEncodingTable(
    reader: anytype,
    min_code_idx: u32,
    max_code_idx: u32,
    allocator: std.mem.Allocator,
) ![]Encoding {
    if (min_code_idx > max_code_idx or max_code_idx >= encoding_table_size) {
        return error.InvalidEncodingTable;
    }
    const encoding_table = try allocator.alloc(Encoding, encoding_table_size);
    errdefer allocator.free(encoding_table);
    @memset(encoding_table, .{ .val = 0, .len = 0 });
    var bit_reader = std.io.bitReader(std.builtin.Endian.big, reader);
    var encoding_idx = min_code_idx;
    while (encoding_idx <= max_code_idx) {
        const encoding_len = try bit_reader.readBitsNoEof(u6, 6);
        if (encoding_len < short_zerocode_run) {
            encoding_table[encoding_idx].len = encoding_len;
            encoding_idx += 1;
            continue;
        }
        // Widen before adding: the 8-bit count plus 6 does not fit in a u8.
        const run_len: u32 = if (encoding_len == long_zerocode_run)
            @as(u32, try bit_reader.readBitsNoEof(u8, 8)) + shortest_long_run
        else
            @as(u32, encoding_len) - short_zerocode_run + 2;
        if (encoding_idx + run_len > max_code_idx + 1) {
            return error.EncodingTableTooLong;
        }
        encoding_idx += run_len;
    }
    return encoding_table;
}

/// Layout of the header written by `compress`.
const EncodingMetadata = packed struct {
    min_code_idx: u32,
    max_code_idx: u32,
    table_size: u32,
    num_bits: u32,
    padding: u32 = 0,
};

/// Compresses `src` into `dst`, starting at the current position of `dst`.
/// The header is written last, so on return `dst` is positioned just after
/// the header, not at the end of the compressed data.
///
/// Empty input produces a header of five zero words and nothing else.
pub fn compress(
    src: []const u16,
    dst: *std.io.StreamSource,
    allocator: std.mem.Allocator,
) !void {
    if (src.len == 0) {
        for (0..5) |_| {
            try dst.writer().writeInt(u32, 0, .little);
        }
        return;
    }
    const frequencies = try countFrequencies(src, allocator);
    defer allocator.free(frequencies);
    const result = try buildEncodingTable(frequencies, allocator);
    defer allocator.free(result.encoding_table);
    const encoding_table = result.encoding_table;
    const min_code_idx = result.min_code_idx;
    const max_code_idx = result.max_code_idx;
    buildCanonicalEncodings(encoding_table);

    // Leave room for the header; its values are only known afterwards.
    const metadata_pos = try dst.getPos();
    try dst.seekBy(5 * @sizeOf(u32));
    const table_size = try packEncodingTable(
        encoding_table,
        min_code_idx,
        max_code_idx,
        dst.writer(),
    );
    const num_bits = try encode(
        src,
        encoding_table,
        max_code_idx,
        dst.writer(),
    );
    try dst.seekTo(metadata_pos);
    try dst.writer().writeInt(u32, min_code_idx, .little);
    try dst.writer().writeInt(u32, max_code_idx, .little);
    try dst.writer().writeInt(u32, table_size, .little);
    try dst.writer().writeInt(u32, num_bits, .little);
    try dst.writer().writeInt(u32, 0, .little);
}

/// Reads one block written by `compress` from `reader` and writes the decoded
/// values to `writer` as little-endian `u16`.
///
/// Malformed input is reported as an error (`error.InvalidEncodingTable`,
/// `error.InvalidEncodingTableEntry`, `error.EncodingTableTooLong`,
/// `error.InvalidCode`, `error.EndOfStream`, ...).
pub fn decompress(
    reader: anytype,
    writer: anytype,
    allocator: std.mem.Allocator,
) !void {
    const min_code_idx = try reader.readInt(u32, .little);
    const max_code_idx = try reader.readInt(u32, .little);
    const table_size = try reader.readInt(u32, .little);
    const num_bits = try reader.readInt(u32, .little);
    _ = try reader.readInt(u32, .little); // padding word
    if (num_bits == 0) {
        // Nothing to decode; step over the table so the reader ends up
        // after this block.
        try reader.skipBytes(table_size, .{});
        return;
    }
    const unpacked = try readEncodingTable(
        reader,
        min_code_idx,
        max_code_idx,
        allocator,
    );
    defer allocator.free(unpacked);
    // The table is untrusted: reject it here instead of tripping the
    // assertions in `buildCanonicalEncodings`.
    try validateCodeLengths(unpacked);
    buildCanonicalEncodings(unpacked);
    const decoding_table = try buildDecodingTable(
        unpacked,
        min_code_idx,
        max_code_idx,
        allocator,
    );
    defer freeDecodingTable(decoding_table, allocator);
    const num_bytes = (@as(u64, num_bits) + 7) / 8;
    var limited_reader = std.io.limitedReader(reader, num_bytes);
    try decode(
        unpacked,
        decoding_table,
        limited_reader.reader(),
        writer,
        num_bits,
        max_code_idx,
    );
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

const test_symbols = [_]u16{
    1, 1, 1, 2, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5,
    5, 5, 6, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
};

// Adapted from tests in the 'exrs' project: https://github.com/johannesvollmer/exrs
const test_uncompressed = [_]u16{
    3852,  2432,  33635, 49381, 10100, 15095, 62693, 63738, 62359, 5013,
    7715,  59875, 28182, 34449, 19983, 20399, 63407, 29486, 4877,  26738,
    44815, 14042, 46091, 48228, 25682, 35412, 7582,  65069, 6632,  54124,
    13798, 27503, 52154, 61961, 30474, 46880, 39097, 15754, 52897, 42371,
    54053, 14178, 48276, 34591, 42602, 32126, 42062, 31474, 16274, 55991,
    2882,  17039, 56389, 20835, 57057, 54081, 3414,  33957, 52584, 10222,
    25139, 40002, 44980, 1602,  48021, 19703, 6562,  61777, 41582, 201,
    31253, 51790, 15888, 40921, 3627,  12184, 16036, 26349, 3159,  29002,
    14535, 50632, 18118, 33583, 18878, 59470, 32835, 9347,  16991, 21303,
    26263, 8312,  14017, 41777, 43240, 3500,  60250, 52437, 45715, 61520,
};

// `test_uncompressed` after encoding; 674 payload bits.
const test_compressed = [_]u8{
    0x10, 0x9,  0xb4, 0xe4, 0x4c, 0xf7, 0xef, 0x42, 0x87, 0x6a, 0xb5, 0xc2,
    0x34, 0x9e, 0x2f, 0x12, 0xae, 0x21, 0x68, 0xf2, 0xa8, 0x74, 0x37, 0xe1,
    0x98, 0x14, 0x59, 0x57, 0x2c, 0x24, 0x3b, 0x35, 0x6c, 0x1b, 0x8b, 0xcc,
    0xe6, 0x13, 0x38, 0xc,  0x8e, 0xe2, 0xc,  0xfe, 0x49, 0x73, 0xbc, 0x2b,
    0x7b, 0x9,  0x27, 0x79, 0x14, 0xc,  0x94, 0x42, 0xf8, 0x7c, 0x1,  0x8d,
    0x26, 0xde, 0x87, 0x26, 0x71, 0x50, 0x45, 0xc6, 0x28, 0x40, 0xd5, 0xe,
    0x8d, 0x8,  0x1e, 0x4c, 0xa4, 0x79, 0x57, 0xf0, 0xc3, 0x6d, 0x5c, 0x6d,
    0xc0,
};
const test_compressed_bits = 674;

/// Checks that `decoded` holds exactly `expected` as little-endian `u16`s.
fn expectDecoded(expected: []const u16, decoded: []const u8) !void {
    try expect(decoded.len == expected.len * @sizeOf(u16));
    for (expected, 0..) |value, i| {
        const actual = std.mem.readInt(u16, decoded[2 * i ..][0..2], .little);
        try expect(actual == value);
    }
}

/// Compresses `src`, decompresses the result and checks the round trip.
fn expectRoundTrip(src: []const u16, allocator: std.mem.Allocator) !void {
    const dst_buffer = try allocator.alloc(u8, 8192);
    defer allocator.free(dst_buffer);
    @memset(dst_buffer, 0);
    var dst = std.io.StreamSource{
        .buffer = std.io.fixedBufferStream(dst_buffer),
    };
    try compress(src, &dst, allocator);
    try dst.seekTo(0);
    var decoded = std.ArrayList(u8).init(allocator);
    defer decoded.deinit();
    try decompress(dst.reader(), decoded.writer(), allocator);
    try expectDecoded(src, decoded.items);
}

test "count frequencies" {
    const allocator = std.testing.allocator;
    const data = [_]u16{ 3, 3, 3, 10, 5, 5, 1, 1, 1, 1 };
    const frequencies = try countFrequencies(&data, allocator);
    defer allocator.free(frequencies);
    try expect(frequencies[1] == 4);
    try expect(frequencies[3] == 3);
    try expect(frequencies[5] == 2);
    try expect(frequencies[10] == 1);
}

test "encoding lengths" {
    const allocator = std.testing.allocator;
    const frequencies = try countFrequencies(&test_symbols, allocator);
    defer allocator.free(frequencies);
    const result = try buildEncodingTable(frequencies, allocator);
    defer allocator.free(result.encoding_table);
    const encoding_table = result.encoding_table;
    const min_code_idx = result.min_code_idx;
    const max_code_idx = result.max_code_idx;
    // Only the symbols in our data should have code lengths.
    for (encoding_table[0..min_code_idx]) |encoding| {
        try expect(encoding.len == 0);
    }
    for (encoding_table[min_code_idx .. max_code_idx + 1]) |encoding| {
        try expect(encoding.len > 0);
    }
    for (encoding_table[max_code_idx + 1 ..]) |encoding| {
        try expect(encoding.len == 0);
    }
    // More frequent symbols shouldn't have longer codes.
    for (min_code_idx..max_code_idx + 1) |i| {
        for (min_code_idx..max_code_idx + 1) |j| {
            if (frequencies[i] <= frequencies[j]) continue;
            try expect(encoding_table[i].len <= encoding_table[j].len);
        }
    }
}

test "encoding table for empty input has a valid index range" {
    const allocator = std.testing.allocator;
    const frequencies = try countFrequencies(&[_]u16{}, allocator);
    defer allocator.free(frequencies);
    const result = try buildEncodingTable(frequencies, allocator);
    defer allocator.free(result.encoding_table);
    try expect(result.min_code_idx <= result.max_code_idx);
}

test "canonical encodings" {
    const allocator = std.testing.allocator;
    const frequencies = try countFrequencies(&test_symbols, allocator);
    defer allocator.free(frequencies);
    const result = try buildEncodingTable(frequencies, allocator);
    defer allocator.free(result.encoding_table);
    const encoding_table = result.encoding_table;
    const min_code_idx = result.min_code_idx;
    const max_code_idx = result.max_code_idx;
    buildCanonicalEncodings(encoding_table);
    // Only the symbols in our data should have code lengths and values.
    // Code values should match the specified length.
    for (encoding_table[0..min_code_idx]) |encoding| {
        try expect(encoding.val == 0);
        try expect(encoding.len == 0);
    }
    for (encoding_table[min_code_idx .. max_code_idx + 1]) |encoding| {
        try expect(encoding.len > 0);
        try expect(encoding.val >> encoding.len == 0);
    }
    for (encoding_table[max_code_idx + 1 ..]) |encoding| {
        try expect(encoding.val == 0);
        try expect(encoding.len == 0);
    }
}

test "encode" {
    const allocator = std.testing.allocator;
    // Build encoding table for data.
    const frequencies = try countFrequencies(&test_uncompressed, allocator);
    defer allocator.free(frequencies);
    const result = try buildEncodingTable(frequencies, allocator);
    defer allocator.free(result.encoding_table);
    buildCanonicalEncodings(result.encoding_table);
    // Encode data into output buffer.
    var encoded = std.ArrayList(u8).init(allocator);
    defer encoded.deinit();
    const num_bits = try encode(
        &test_uncompressed,
        result.encoding_table,
        result.max_code_idx,
        encoded.writer(),
    );
    try expect(num_bits == test_compressed_bits);
    try expect(std.mem.eql(u8, encoded.items, &test_compressed));
}

test "encode empty input" {
    const allocator = std.testing.allocator;
    const table = [_]Encoding{.{ .val = 0, .len = 1 }} ** 2;
    var encoded = std.ArrayList(u8).init(allocator);
    defer encoded.deinit();
    const num_bits = try encode(&[_]u16{}, &table, 1, encoded.writer());
    try expect(num_bits == 0);
    try expect(encoded.items.len == 0);
}

test "build decoding table" {
    const allocator = std.testing.allocator;
    // Build encoding and decoding tables.
    const frequencies = try countFrequencies(&test_uncompressed, allocator);
    defer allocator.free(frequencies);
    const result = try buildEncodingTable(frequencies, allocator);
    defer allocator.free(result.encoding_table);
    const encoding_table = result.encoding_table;
    buildCanonicalEncodings(encoding_table);
    const decoding_table = try buildDecodingTable(
        encoding_table,
        result.min_code_idx,
        result.max_code_idx,
        allocator,
    );
    defer freeDecodingTable(decoding_table, allocator);
    // Doesn't test long codes, but at least we know short codes work.
    for (test_uncompressed) |x| {
        const encoding = encoding_table[x];
        const decoding_idx: usize = @intCast(
            @as(u64, encoding.val) << (decoding_bits - encoding.len),
        );
        const decoding = decoding_table[decoding_idx];
        try expect(decoding.short.val == x);
    }
}

test "decode" {
    const allocator = std.testing.allocator;
    // Build encoding and decoding tables.
    const frequencies = try countFrequencies(&test_uncompressed, allocator);
    defer allocator.free(frequencies);
    const result = try buildEncodingTable(frequencies, allocator);
    defer allocator.free(result.encoding_table);
    buildCanonicalEncodings(result.encoding_table);
    const decoding_table = try buildDecodingTable(
        result.encoding_table,
        result.min_code_idx,
        result.max_code_idx,
        allocator,
    );
    defer freeDecodingTable(decoding_table, allocator);
    // Decode compressed data.
    var stream = std.io.fixedBufferStream(&test_compressed);
    var decoded = std.ArrayList(u8).init(allocator);
    defer decoded.deinit();
    try decode(
        result.encoding_table,
        decoding_table,
        stream.reader(),
        decoded.writer(),
        test_compressed_bits,
        result.max_code_idx,
    );
    try expectDecoded(&test_uncompressed, decoded.items);
}

test "encode then decode" {
    const allocator = std.testing.allocator;
    // Build encoding table for data.
    const frequencies = try countFrequencies(&test_uncompressed, allocator);
    defer allocator.free(frequencies);
    const result = try buildEncodingTable(frequencies, allocator);
    defer allocator.free(result.encoding_table);
    buildCanonicalEncodings(result.encoding_table);
    // Encode data into output buffer.
    var encoded = std.ArrayList(u8).init(allocator);
    defer encoded.deinit();
    const num_bits = try encode(
        &test_uncompressed,
        result.encoding_table,
        result.max_code_idx,
        encoded.writer(),
    );
    // Build decoding table.
    const decoding_table = try buildDecodingTable(
        result.encoding_table,
        result.min_code_idx,
        result.max_code_idx,
        allocator,
    );
    defer freeDecodingTable(decoding_table, allocator);
    // Decode compressed data.
    var stream = std.io.fixedBufferStream(encoded.items);
    var decoded = std.ArrayList(u8).init(allocator);
    defer decoded.deinit();
    try decode(
        result.encoding_table,
        decoding_table,
        stream.reader(),
        decoded.writer(),
        num_bits,
        result.max_code_idx,
    );
    try expectDecoded(&test_uncompressed, decoded.items);
}

test "pack and unpack encoding table" {
    const allocator = std.testing.allocator;
    // Build encoding table.
    const frequencies = try countFrequencies(&test_symbols, allocator);
    defer allocator.free(frequencies);
    const result = try buildEncodingTable(frequencies, allocator);
    defer allocator.free(result.encoding_table);
    const encoding_table = result.encoding_table;
    buildCanonicalEncodings(encoding_table);
    // Pack encoding table.
    var packed_table = std.ArrayList(u8).init(allocator);
    defer packed_table.deinit();
    const table_size = try packEncodingTable(
        encoding_table,
        result.min_code_idx,
        result.max_code_idx,
        packed_table.writer(),
    );
    try expect(table_size == packed_table.items.len);
    // Unpack encoding table.
    var stream = std.io.fixedBufferStream(packed_table.items);
    const unpacked = try readEncodingTable(
        stream.reader(),
        result.min_code_idx,
        result.max_code_idx,
        allocator,
    );
    defer allocator.free(unpacked);
    buildCanonicalEncodings(unpacked);
    // The encoding tables should be identical.
    for (encoding_table, unpacked) |e1, e2| {
        try expect(std.meta.eql(e1, e2));
    }
}

test "read encoding table rejects an out-of-range header" {
    const allocator = std.testing.allocator;
    var stream = std.io.fixedBufferStream(&[_]u8{ 0, 0, 0, 0 });
    try std.testing.expectError(
        error.InvalidEncodingTable,
        readEncodingTable(stream.reader(), 0, encoding_table_size, allocator),
    );
    try std.testing.expectError(
        error.InvalidEncodingTable,
        readEncodingTable(stream.reader(), 5, 4, allocator),
    );
}

test "compress then decompress" {
    try expectRoundTrip(&test_symbols, std.testing.allocator);
    try expectRoundTrip(&test_uncompressed, std.testing.allocator);
}

test "compress then decompress long runs and sparse symbols" {
    // Long runs force run-length codes in the payload, the wide gaps between
    // symbols force long zero runs in the packed table, and 65535 puts the
    // run-length pseudo-symbol at index 65536.
    var src: [700]u16 = undefined;
    for (&src, 0..) |*value, i| {
        value.* = if (i < 300)
            7
        else if (i < 600)
            9
        else
            @as(u16, @intCast(1000 + (i % 4) * 20000));
    }
    src[699] = 65535;
    try expectRoundTrip(&src, std.testing.allocator);
}

test "compress then decompress empty input" {
    try expectRoundTrip(&[_]u16{}, std.testing.allocator);
}