LCOV - code coverage report
Current view: top level - lib/compression - lzxpress_huffman.c (source / functions) Hit Total Coverage
Test: coverage report for recycleplus df22b230 Lines: 0 737 0.0 %
Date: 2024-02-14 10:14:15 Functions: 0 32 0.0 %

          Line data    Source code
       1             : /*
       2             :  * Samba compression library - LGPLv3
       3             :  *
       4             :  * Copyright © Catalyst IT 2022
       5             :  *
       6             :  * Written by Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
       7             :  *        and Joseph Sutton   <josephsutton@catalyst.net.nz>
       8             :  *
       9             :  *  ** NOTE! The following LGPL license applies to this file.
      10             :  *  ** It does NOT imply that all of Samba is released under the LGPL
      11             :  *
      12             :  *  This library is free software; you can redistribute it and/or
      13             :  *  modify it under the terms of the GNU Lesser General Public
      14             :  *  License as published by the Free Software Foundation; either
      15             :  *  version 3 of the License, or (at your option) any later version.
      16             :  *
      17             :  *  This library is distributed in the hope that it will be useful,
      18             :  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
      19             :  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
      20             :  *  Lesser General Public License for more details.
      21             :  *
      22             :  *  You should have received a copy of the GNU Lesser General Public
      23             :  *  License along with this library; if not, see <http://www.gnu.org/licenses/>.
      24             :  */
      25             : 
      26             : #include <talloc.h>
      27             : 
      28             : #include "replace.h"
      29             : #include "lzxpress_huffman.h"
      30             : #include "lib/util/stable_sort.h"
      31             : #include "lib/util/debug.h"
      32             : #include "lib/util/byteorder.h"
      33             : #include "lib/util/bytearray.h"
      34             : 
      35             : /*
      36             :  * DEBUG_NO_LZ77_MATCHES toggles the encoding of matches as matches. If it is
      37             :  * false the potential match is written as a series of literals, which is a
      38             :  * valid but usually inefficient encoding. This is useful for isolating a
      39             :  * problem to either the LZ77 or the Huffman stage.
      40             :  */
      41             : #ifndef DEBUG_NO_LZ77_MATCHES
      42             : #define DEBUG_NO_LZ77_MATCHES false
      43             : #endif
      44             : 
      45             : /*
      46             :  * DEBUG_HUFFMAN_TREE forces the drawing of ascii art huffman trees during
      47             :  * compression and decompression.
      48             :  *
      49             :  * These trees will also be drawn at DEBUG level 10, but that doesn't work
      50             :  * with cmocka tests.
      51             :  */
      52             : #ifndef DEBUG_HUFFMAN_TREE
      53             : #define DEBUG_HUFFMAN_TREE false
      54             : #endif
      55             : 
      56             : #if DEBUG_HUFFMAN_TREE
      57             : #define DBG(...) fprintf(stderr, __VA_ARGS__)
      58             : #else
      59             : #define DBG(...) DBG_INFO(__VA_ARGS__)
      60             : #endif
      61             : 
      62             : 
      63             : #define LZXPRESS_ERROR -1LL
      64             : 
      65             : /*
      66             :  * We won't encode a match length longer than MAX_MATCH_LENGTH.
      67             :  *
      68             :  * Reports are that Windows has a limit at 64M.
      69             :  */
      70             : #define MAX_MATCH_LENGTH (64 * 1024 * 1024)
      71             : 
      72             : 
      73             : struct bitstream {
      74             :         const uint8_t *bytes;
      75             :         size_t byte_pos;
      76             :         size_t byte_size;
      77             :         uint32_t bits;
      78             :         int remaining_bits;
      79             :         uint16_t *table;
      80             : };
      81             : 
      82             : 
      83             : #if ! defined __has_builtin
      84             : #define __has_builtin(x) 0
      85             : #endif
      86             : 
      87             : /*
      88             :  * bitlen_nonzero_16() returns the bit number of the most significant bit, or
      89             :  * put another way, the integer log base 2. Log(0) is undefined; the argument
      90             :  * has to be non-zero!
      91             :  * 1     -> 0
      92             :  * 2,3   -> 1
      93             :  * 4-7   -> 2
      94             :  * 1024  -> 10, etc
      95             :  *
      96             :  * Probably this is handled by a compiler intrinsic function that maps to a
      97             :  * dedicated machine instruction.
      98             :  */
      99             : 
     100           0 : static inline int bitlen_nonzero_16(uint16_t x)
     101             : {
     102             : #if  __has_builtin(__builtin_clz)
     103             : 
     104             :         /* __builtin_clz returns the number of leading zeros */
     105             :         return (sizeof(unsigned int) * CHAR_BIT) - 1
     106           0 :                 - __builtin_clz((unsigned int) x);
     107             : 
     108             : #else
     109             : 
     110             :         int count = -1;
     111             :         while(x) {
     112             :                 x >>= 1;
     113             :                 count++;
     114             :         }
     115             :         return count;
     116             : 
     117             : #endif
     118             : }
     119             : 
     120             : 
     121             : struct lzxhuff_compressor_context {
     122             :         const uint8_t *input_bytes;
     123             :         size_t input_size;
     124             :         size_t input_pos;
     125             :         size_t prev_block_pos;
     126             :         uint8_t *output;
     127             :         size_t available_size;
     128             :         size_t output_pos;
     129             : };
     130             : 
     131           0 : static int compare_huffman_node_count(struct huffman_node *a,
     132             :                                       struct huffman_node *b)
     133             : {
     134           0 :         return a->count - b->count;
     135             : }
     136             : 
     137           0 : static int compare_huffman_node_depth(struct huffman_node *a,
     138             :                                       struct huffman_node *b)
     139             : {
     140           0 :         int c = a->depth - b->depth;
     141           0 :         if (c != 0) {
     142           0 :                 return c;
     143             :         }
     144           0 :         return (int)a->symbol - (int)b->symbol;
     145             : }
     146             : 
     147             : 
     148             : #define HASH_MASK ((1 << LZX_HUFF_COMP_HASH_BITS) - 1)
     149             : 
     150           0 : static inline uint16_t three_byte_hash(const uint8_t *bytes)
     151             : {
     152             :         /*
     153             :          * MS-XCA says "three byte hash", but does not specify it.
     154             :          *
     155             :          * This one is just cobbled together, but has quite good distribution
     156             :          * in the 12-14 bit forms, which is what we care about most.
     157             :          * e.g: 13 bit: median 2048, min 2022, max 2074, stddev 6.0
     158             :          */
     159           0 :         uint16_t a = bytes[0];
     160           0 :         uint16_t b = bytes[1] ^ 0x2e;
     161           0 :         uint16_t c = bytes[2] ^ 0x55;
     162           0 :         uint16_t ca = c - a;
     163           0 :         uint16_t d = ((a + b) << 8) ^ (ca << 5) ^ (c + b) ^ (0xcab + a);
     164           0 :         return d & HASH_MASK;
     165             : }
     166             : 
     167             : 
     168           0 : static inline uint16_t encode_match(size_t len, size_t offset)
     169             : {
     170           0 :         uint16_t code = 256;
     171           0 :         code |= MIN(len - 3, 15);
     172           0 :         code |= bitlen_nonzero_16(offset) << 4;
     173           0 :         return code;
     174             : }
     175             : 
     176             : /*
     177             :  * debug_huffman_tree() uses debug_huffman_tree_print() to draw the Huffman
     178             :  * tree in ascii art.
     179             :  *
     180             :  * Note that the Huffman tree is probably not the same as that implied by the
     181             :  * canonical Huffman encoding that is finally used. That tree would be the
     182             :  * same shape, but with the left and right toggled to sort the branches by
     183             :  * length, after which the symbols for each length sorted by value.
     184             :  */
     185             : 
     186           0 : static void debug_huffman_tree_print(struct huffman_node *node,
     187             :                                      int *trail, int depth)
     188             : {
     189           0 :         if (node->left == NULL) {
     190             :                 /* time to print a row */
     191             :                 int j;
     192           0 :                 bool branched = false;
     193             :                 int row[17];
     194             :                 char c[100];
     195           0 :                 int s = node->symbol;
     196             :                 char code[17];
     197           0 :                 if (depth > 15) {
     198           0 :                         fprintf(stderr,
     199             :                                 " \033[1;31m Max depth exceeded! (%d)\033[0m "
     200             :                                 " symbol %#3x claimed depth %d count %d\n",
     201           0 :                                 depth, node->symbol, node->depth, node->count);
     202           0 :                         return;
     203             :                 }
     204           0 :                 for (j = depth - 1; j >= 0; j--) {
     205           0 :                         if (branched) {
     206           0 :                                 if (trail[j] == -1) {
     207           0 :                                         row[j] = -3;
     208             :                                 } else {
     209           0 :                                         row[j] = -2;
     210             :                                 }
     211           0 :                         } else if (trail[j] == -1) {
     212           0 :                                 row[j] = -1;
     213           0 :                                 branched = true;
     214             :                         } else {
     215           0 :                                 row[j] = trail[j];
     216             :                         }
     217             :                 }
     218           0 :                 for (j = 0; j < depth; j++) {
     219           0 :                         switch (row[j]) {
     220           0 :                         case -3:
     221           0 :                                 code[j] = '1';
     222           0 :                                 fprintf(stderr, "        ");
     223           0 :                                 break;
     224           0 :                         case -2:
     225           0 :                                 code[j] = '0';
     226           0 :                                 fprintf(stderr, "      │ ");
     227           0 :                                 break;
     228           0 :                         case -1:
     229           0 :                                 code[j] = '1';
     230           0 :                                 fprintf(stderr, "      ╰─");
     231           0 :                                 break;
     232           0 :                         default:
     233           0 :                                 code[j] = '0';
     234           0 :                                 fprintf(stderr, "%5d─┬─", row[j]);
     235           0 :                                 break;
     236             :                         }
     237             :                 }
     238           0 :                 code[depth] = 0;
     239           0 :                 if (s < 32) {
     240           0 :                         snprintf(c, sizeof(c),
     241             :                                 "\033[1;32m%02x\033[0m \033[1;33m%c%c%c\033[0m",
     242             :                                  s,
     243             :                                  0xE2, 0x90, 0x80 + s); /* utf-8 for symbol */
     244           0 :                 }  else if (s < 127) {
     245           0 :                         snprintf(c, sizeof(c),
     246             :                                  "\033[1;32m%2x\033[0m '\033[10;32m%c\033[0m'",
     247             :                                  s, s);
     248           0 :                 } else if (s < 256) {
     249           0 :                         snprintf(c, sizeof(c), "\033[1;32m%2x\033[0m", s);
     250             :                 } else {
     251           0 :                         uint16_t len = (s & 15) + 3;
     252           0 :                         uint16_t dbits = ((s >> 4) & 15) + 1;
     253           0 :                         snprintf(c, sizeof(c),
     254             :                                  " \033[0;33mlen:%2d%s, "
     255             :                                  "dist:%d-%d \033[0m \033[1;32m%3x\033[0m%s",
     256             :                                  len,
     257             :                                  len == 18 ? "+" : "",
     258           0 :                                  1 << (dbits - 1),
     259           0 :                                  (1 << dbits) - 1,
     260             :                                  s,
     261             :                                  s == 256 ? " \033[1;31mEOF\033[0m" : "");
     262             : 
     263             :                 }
     264             : 
     265           0 :                 fprintf(stderr, "──%5d %s \033[2;37m%s\033[0m\n",
     266             :                         node->count, c, code);
     267           0 :                 return;
     268             :         }
     269           0 :         trail[depth] = node->count;
     270           0 :         debug_huffman_tree_print(node->left, trail, depth + 1);
     271           0 :         trail[depth] = -1;
     272           0 :         debug_huffman_tree_print(node->right, trail, depth + 1);
     273             : }
     274             : 
     275             : 
     276             : /*
     277             :  * If DEBUG_HUFFMAN_TREE is defined true, debug_huffman_tree()
     278             :  * will print a tree looking something like this:
     279             :  *
     280             :  *     7─┬───    3  len:18+, dist:1-1  10f 0
     281             :  *       ╰─    4─┬─    2─┬───    1 61 'a' 100
     282             :  *               │       ╰───    1 62 'b' 101
     283             :  *               ╰─    2─┬───    1 63 'c' 110
     284             :  *                       ╰───    1  len: 3, dist:1-1  100 EOF 111
     285             :  *
     286             :  * This is based off a Huffman root node, and the tree may not be the same as
     287             :  * the canonical tree.
     288             :  */
     289           0 : static void debug_huffman_tree(struct huffman_node *root)
     290             : {
     291             :         int trail[17];
     292           0 :         debug_huffman_tree_print(root, trail, 0);
     293           0 : }
     294             : 
     295             : 
     296             : /*
     297             :  * If DEBUG_HUFFMAN_TREE is defined true, debug_huffman_tree_from_table()
     298             :  * will print something like this based on a decoding symbol table.
     299             :  *
     300             :  *  Tree from decoding table 9 nodes → 5 codes
     301             :  * 10000─┬─── 5000  len:18+, dist:1-1  10f 0
     302             :  *       ╰─ 5000─┬─ 2500─┬─── 1250 61 'a' 100
     303             :  *               │       ╰─── 1250 62 'b' 101
     304             :  *               ╰─ 2500─┬─── 1250 63 'c' 110
     305             :  *                       ╰─── 1250  len: 3, dist:1-1  100 EOF 111
     306             :  *
     307             :  * This is the canonical form of the Huffman tree where the actual counts
     308             :  * aren't known (we use "10000" to help indicate relative frequencies).
     309             :  */
     310           0 : static void debug_huffman_tree_from_table(uint16_t *table)
     311             : {
     312             :         int trail[17];
     313           0 :         struct huffman_node nodes[1024] = {{0}};
     314             :         uint16_t codes[1024];
     315           0 :         size_t n = 1;
     316           0 :         size_t i = 0;
     317           0 :         codes[0] = 0;
     318           0 :         nodes[0].count = 10000;
     319             : 
     320           0 :         while (i < n) {
     321           0 :                 uint16_t index = codes[i];
     322           0 :                 struct huffman_node *node = &nodes[i];
     323           0 :                 if (table[index] == 0xffff) {
     324             :                         /* internal node */
     325           0 :                         index <<= 1;
     326             :                         /* left */
     327           0 :                         index++;
     328           0 :                         codes[n] = index;
     329           0 :                         node->left = nodes + n;
     330           0 :                         nodes[n].count = node->count >> 1;
     331           0 :                         n++;
     332             :                         /*right*/
     333           0 :                         index++;
     334           0 :                         codes[n] = index;
     335           0 :                         node->right = nodes + n;
     336           0 :                         nodes[n].count = node->count >> 1;
     337           0 :                         n++;
     338             :                 } else {
     339             :                         /* leaf node */
     340           0 :                         node->symbol = table[index] & 511;
     341             :                 }
     342           0 :                 i++;
     343             :         }
     344             : 
     345           0 :         fprintf(stderr,
     346             :                 "\033[1;34m Tree from decoding table\033[0m "
     347             :                 "%zu nodes → %zu codes\n",
     348           0 :                 n, (n + 1) / 2);
     349           0 :         debug_huffman_tree_print(nodes, trail, 0);
     350           0 : }
     351             : 
     352             : 
     353           0 : static bool depth_walk(struct huffman_node *n, uint32_t depth)
     354             : {
     355             :         bool ok;
     356           0 :         if (n->left == NULL) {
     357             :                 /* this is a leaf, record the depth */
     358           0 :                 n->depth = depth;
     359           0 :                 return true;
     360             :         }
     361           0 :         if (depth > 14) {
     362           0 :                 return false;
     363             :         }
     364           0 :         ok = (depth_walk(n->left, depth + 1) &&
     365           0 :               depth_walk(n->right, depth + 1));
     366             : 
     367           0 :         return ok;
     368             : }
     369             : 
     370             : 
     371           0 : static bool check_and_record_depths(struct huffman_node *root)
     372             : {
     373           0 :         return depth_walk(root, 0);
     374             : }
     375             : 
     376             : 
     377           0 : static bool encode_values(struct huffman_node *leaves,
     378             :                           size_t n_leaves,
     379             :                           uint16_t symbol_values[512])
     380             : {
     381             :         size_t i;
     382             :         /*
     383             :          * See, we have a leading 1 in our internal code representation, which
     384             :          * indicates the code length.
     385             :          */
     386           0 :         uint32_t code = 1;
     387           0 :         uint32_t code_len = 0;
     388           0 :         memset(symbol_values, 0, sizeof(uint16_t) * 512);
     389           0 :         for (i = 0; i < n_leaves; i++) {
     390           0 :                 code <<= leaves[i].depth - code_len;
     391           0 :                 code_len = leaves[i].depth;
     392             : 
     393           0 :                 symbol_values[leaves[i].symbol] = code;
     394           0 :                 code++;
     395             :         }
     396             :         /*
     397             :          * The last code should be 11111... with code_len + 1 ones. The final
     398             :          * code++ will wrap this round to 1000... with code_len + 1 zeroes.
     399             :          */
     400             : 
     401           0 :         if (code != 2 << code_len) {
     402           0 :                 return false;
     403             :         }
     404           0 :         return true;
     405             : }
     406             : 
     407             : 
     408           0 : static int generate_huffman_codes(struct huffman_node *leaf_nodes,
     409             :                                   struct huffman_node *internal_nodes,
     410             :                                   uint16_t symbol_values[512])
     411             : {
     412           0 :         size_t head_leaf = 0;
     413           0 :         size_t head_branch = 0;
     414           0 :         size_t tail_branch = 0;
     415           0 :         struct huffman_node *huffman_root = NULL;
     416             :         size_t i, j;
     417           0 :         size_t n_leaves = 0;
     418             : 
     419             :         /*
     420             :          * Before we sort the nodes, we can eliminate the unused ones.
     421             :          */
     422           0 :         for (i = 0; i < 512; i++) {
     423           0 :                 if (leaf_nodes[i].count) {
     424           0 :                         leaf_nodes[n_leaves] = leaf_nodes[i];
     425           0 :                         n_leaves++;
     426             :                 }
     427             :         }
     428           0 :         if (n_leaves == 0) {
     429           0 :                 return LZXPRESS_ERROR;
     430             :         }
     431           0 :         if (n_leaves == 1) {
     432             :                 /*
     433             :                  * There is *almost* no way this should happen, and it would
     434             :                  * ruin the tree (because the shortest possible codes are 1
     435             :                  * bit long, and there are two of them).
     436             :                  *
     437             :                  * The only way to get here is in an internal block in a
     438             :                  * 3-or-more block message (i.e. > 128k), which consists
     439             :                  * entirely of a match starting in the previous block (if it
     440             :                  * was the end block, it would have the EOF symbol).
     441             :                  *
     442             :                  * What we do is add a dummy symbol which is this one XOR 256.
     443             :                  * It won't be used in the stream but will balance the tree.
     444             :                  */
     445           0 :                 leaf_nodes[1] = leaf_nodes[0];
     446           0 :                 leaf_nodes[1].symbol ^= 0x100;
     447           0 :                 n_leaves = 2;
     448             :         }
     449             : 
     450             :         /* note, in sort we're using internal_nodes as auxiliary space */
     451           0 :         stable_sort(leaf_nodes,
     452             :                     internal_nodes,
     453             :                     n_leaves,
     454             :                     sizeof(struct huffman_node),
     455             :                     (samba_compare_fn_t)compare_huffman_node_count);
     456             : 
     457             :         /*
     458             :          * This outer loop is for re-quantizing the counts if the tree is too
     459             :          * tall (>15), which we need to do because the final encoding can't
     460             :          * express a tree that deep.
     461             :          *
     462             :          * In theory, this should be a 'while (true)' loop, but we chicken
     463             :          * out with 10 iterations, just in case.
     464             :          *
     465             :          * In practice it will almost always resolve in the first round; if
     466             :          * not then, in the second or third. Remember we're looking at 64k or
     467             :          * less, so the rarest we can have is 1 in 64k; each round of
     468             :          * quantization effectively doubles its frequency to 1 in 32k, 1 in
     469             :          * 16k, etc, until we're treating the rare symbol as actually quite
     470             :          * common.
     471             :          */
     472           0 :         for (j = 0; j < 10; j++) {
     473             :                 bool less_than_15_bits;
     474           0 :                 while (true) {
     475           0 :                         struct huffman_node *a = NULL;
     476           0 :                         struct huffman_node *b = NULL;
     477           0 :                         size_t leaf_len = n_leaves - head_leaf;
     478           0 :                         size_t internal_len = tail_branch - head_branch;
     479             : 
     480           0 :                         if (leaf_len + internal_len == 1) {
     481             :                                 /*
     482             :                                  * We have the complete tree. The root will be
     483             :                                  * an internal node unless there is just one
     484             :                                  * symbol, which is already impossible.
     485             :                                  */
     486           0 :                                 if (unlikely(leaf_len == 1)) {
     487           0 :                                         return LZXPRESS_ERROR;
     488             :                                 } else {
     489           0 :                                         huffman_root = \
     490           0 :                                                 &internal_nodes[head_branch];
     491             :                                 }
     492           0 :                                 break;
     493             :                         }
     494             :                         /*
     495             :                          * We know here we have at least two nodes, and we
     496             :                          * want to select the two lowest scoring ones. Those
     497             :                          * have to be either a) the head of each queue, or b)
     498             :                          * the first two nodes of either queue.
     499             :                          *
     500             :                          * The complicating factors are: a) we need to check
     501             :                          * the length of each queue, and b) in the case of
     502             :                          * ties, we prefer to pair leaves with leaves.
     503             :                          *
     504             :                          * Note a complication we don't have: the leaf node
     505             :                          * queue never grows, and the subtree queue starts
     506             :                          * empty and cannot grow beyond n - 1. It feeds on
     507             :                          * itself. We don't need to think about overflow.
     508             :                          */
     509           0 :                         if (leaf_len == 0) {
     510             :                                 /* two from subtrees */
     511           0 :                                 a = &internal_nodes[head_branch];
     512           0 :                                 b = &internal_nodes[head_branch + 1];
     513           0 :                                 head_branch += 2;
     514           0 :                         } else if (internal_len == 0) {
     515             :                                 /* two from nodes */
     516           0 :                                 a = &leaf_nodes[head_leaf];
     517           0 :                                 b = &leaf_nodes[head_leaf + 1];
     518           0 :                                 head_leaf += 2;
     519           0 :                         } else if (leaf_len == 1 && internal_len == 1) {
     520             :                                 /* one of each */
     521           0 :                                 a = &leaf_nodes[head_leaf];
     522           0 :                                 b = &internal_nodes[head_branch];
     523           0 :                                 head_branch++;
     524           0 :                                 head_leaf++;
     525             :                         } else {
     526             :                                 /*
     527             :                                  * Take the lowest head, twice, checking for
     528             :                                  * length after taking the first one.
     529             :                                  */
     530           0 :                                 if (leaf_nodes[head_leaf].count >
     531           0 :                                     internal_nodes[head_branch].count) {
     532           0 :                                         a = &internal_nodes[head_branch];
     533           0 :                                         head_branch++;
     534           0 :                                         if (internal_len == 1) {
     535           0 :                                                 b = &leaf_nodes[head_leaf];
     536           0 :                                                 head_leaf++;
     537           0 :                                                 goto done;
     538             :                                         }
     539             :                                 } else {
     540           0 :                                         a = &leaf_nodes[head_leaf];
     541           0 :                                         head_leaf++;
     542           0 :                                         if (leaf_len == 1) {
     543           0 :                                                 b = &internal_nodes[head_branch];
     544           0 :                                                 head_branch++;
     545           0 :                                                 goto done;
     546             :                                         }
     547             :                                 }
     548             :                                 /* the other node */
     549           0 :                                 if (leaf_nodes[head_leaf].count >
     550           0 :                                     internal_nodes[head_branch].count) {
     551           0 :                                         b = &internal_nodes[head_branch];
     552           0 :                                         head_branch++;
     553             :                                 } else {
     554           0 :                                         b = &leaf_nodes[head_leaf];
     555           0 :                                         head_leaf++;
     556             :                                 }
     557             :                         }
     558           0 :                 done:
     559             :                         /*
     560             :                          * Now we add a new node to the subtrees list that
     561             :                          * combines the score of node_a and node_b, and points
     562             :                          * to them as children.
     563             :                          */
     564           0 :                         internal_nodes[tail_branch].count = a->count + b->count;
     565           0 :                         internal_nodes[tail_branch].left = a;
     566           0 :                         internal_nodes[tail_branch].right = b;
     567           0 :                         tail_branch++;
     568           0 :                         if (tail_branch == n_leaves) {
     569             :                                 /*
     570             :                                  * We're not getting here, no way, never ever.
     571             :                                  * Unless we made a terible mistake.
     572             :                                  *
     573             :                                  * That is, in a binary tree with n leaves,
     574             :                                  * there are ALWAYS n-1 internal nodes.
     575             :                                  */
     576           0 :                                 return LZXPRESS_ERROR;
     577             :                         }
     578             :                 }
     579           0 :                 if (CHECK_DEBUGLVL(10) || DEBUG_HUFFMAN_TREE) {
     580           0 :                         debug_huffman_tree(huffman_root);
     581             :                 }
     582             :                 /*
     583             :                  * We have a tree, and need to turn it into a lookup table,
     584             :                  * and see if it is shallow enough (<= 15).
     585             :                  */
     586           0 :                 less_than_15_bits = check_and_record_depths(huffman_root);
     587           0 :                 if (less_than_15_bits) {
     588             :                         /*
     589             :                          * Now the leaf nodes know how deep they are, and we
     590             :                          * no longer need the internal nodes.
     591             :                          *
     592             :                          * We need to sort the nodes of equal depth, so that
     593             :                          * they are sorted by depth first, and symbol value
     594             :                          * second. The internal_nodes can again be auxillary
     595             :                          * memory.
     596             :                          */
     597           0 :                         stable_sort(
     598             :                                 leaf_nodes,
     599             :                                 internal_nodes,
     600             :                                 n_leaves,
     601             :                                 sizeof(struct huffman_node),
     602             :                                 (samba_compare_fn_t)compare_huffman_node_depth);
     603             : 
     604           0 :                         encode_values(leaf_nodes, n_leaves, symbol_values);
     605             : 
     606           0 :                         return n_leaves;
     607             :                 }
     608             : 
     609             :                 /*
     610             :                  * requantize by halfing and rounding up, so that small counts
     611             :                  * become relatively bigger. This will lead to a flatter tree.
     612             :                  */
     613           0 :                 for (i = 0; i < n_leaves; i++) {
     614           0 :                         leaf_nodes[i].count >>= 1;
     615           0 :                         leaf_nodes[i].count += 1;
     616             :                 }
     617           0 :                 head_leaf = 0;
     618           0 :                 head_branch = 0;
     619           0 :                 tail_branch = 0;
     620             :         }
     621           0 :         return LZXPRESS_ERROR;
     622             : }
     623             : 
     624             : /*
     625             :  * LZX_HUFF_COMP_HASH_SEARCH_ATTEMPTS is how far ahead to search in the
     626             :  * circular hash table for a match, before we give up. A bigger number will
     627             :  * generally lead to better but slower compression, but a stupidly big number
     628             :  * will just be worse.
     629             :  *
     630             :  * If you're fiddling with this, consider also fiddling with
     631             :  * LZX_HUFF_COMP_HASH_BITS.
     632             :  */
     633             : #define LZX_HUFF_COMP_HASH_SEARCH_ATTEMPTS 5
     634             : 
     635           0 : static inline void store_match(uint16_t *hash_table,
     636             :                                uint16_t h,
     637             :                                uint16_t offset)
     638             : {
     639             :         int i;
     640           0 :         uint16_t o = hash_table[h];
     641             :         uint16_t h2;
     642             :         uint16_t worst_h;
     643             :         int worst_score;
     644             : 
     645           0 :         if (o == 0xffff) {
     646             :                 /* there is nothing there yet */
     647           0 :                 hash_table[h] = offset;
     648           0 :                 return;
     649             :         }
     650           0 :         for (i = 1; i < LZX_HUFF_COMP_HASH_SEARCH_ATTEMPTS; i++) {
     651           0 :                 h2 = (h + i) & HASH_MASK;
     652           0 :                 if (hash_table[h2] == 0xffff) {
     653           0 :                         hash_table[h2] = offset;
     654           0 :                         return;
     655             :                 }
     656             :         }
     657             :         /*
     658             :          * There are no slots, but we really want to store this, so we'll kick
     659             :          * out the one with the longest distance.
     660             :          */
     661           0 :         worst_h = h;
     662           0 :         worst_score = offset - o;
     663           0 :         for (i = 1; i < LZX_HUFF_COMP_HASH_SEARCH_ATTEMPTS; i++) {
     664             :                 int score;
     665           0 :                 h2 = (h + i) & HASH_MASK;
     666           0 :                 o = hash_table[h2];
     667           0 :                 score = offset - o;
     668           0 :                 if (score > worst_score) {
     669           0 :                         worst_score = score;
     670           0 :                         worst_h = h2;
     671             :                 }
     672             :         }
     673           0 :         hash_table[worst_h] = offset;
     674             : }
     675             : 
     676             : 
     677             : /*
     678             :  * Yes, struct match looks a lot like a DATA_BLOB.
     679             :  */
     680             : struct match {
     681             :         const uint8_t *there;
     682             :         size_t length;
     683             : };
     684             : 
     685             : 
     686           0 : static inline struct match lookup_match(uint16_t *hash_table,
     687             :                                         uint16_t h,
     688             :                                         const uint8_t *data,
     689             :                                         const uint8_t *here,
     690             :                                         size_t max_len)
     691             : {
     692             :         int i;
     693           0 :         uint16_t o = hash_table[h];
     694             :         uint16_t h2;
     695             :         size_t len;
     696           0 :         const uint8_t *there = NULL;
     697           0 :         struct match best = {0};
     698             : 
     699           0 :         for (i = 0; i < LZX_HUFF_COMP_HASH_SEARCH_ATTEMPTS; i++) {
     700           0 :                 h2 = (h + i) & HASH_MASK;
     701           0 :                 o = hash_table[h2];
     702           0 :                 if (o == 0xffff) {
     703             :                         /*
     704             :                          * in setting this, we would never have stepped over
     705             :                          * an 0xffff, so we won't now.
     706             :                          */
     707           0 :                         break;
     708             :                 }
     709           0 :                 there = data + o;
     710           0 :                 if (here - there > 65534 || there > here) {
     711           0 :                         continue;
     712             :                 }
     713             : 
     714             :                 /*
     715             :                  * When we already have a long match, we can try to avoid
     716             :                  * measuring out another long, but shorter match.
     717             :                  */
     718           0 :                 if (best.length > 1000 &&
     719           0 :                     there[best.length - 1] != best.there[best.length - 1]) {
     720           0 :                         continue;
     721             :                 }
     722             : 
     723           0 :                 for (len = 0;
     724           0 :                      len < max_len && here[len] == there[len];
     725           0 :                      len++) {
     726             :                         /* counting */
     727             :                 }
     728           0 :                 if (len > 2) {
     729             :                         /*
     730             :                          * As a tiebreaker, we prefer the closer match which
     731             :                          * is likely to encode smaller (and certainly no worse).
     732             :                          */
     733           0 :                         if (len > best.length ||
     734           0 :                             (len == best.length && there > best.there)) {
     735           0 :                                 best.length = len;
     736           0 :                                 best.there = there;
     737             :                         }
     738             :                 }
     739             :         }
     740           0 :         return best;
     741             : }
     742             : 
     743             : 
     744             : 
     745           0 : static ssize_t lz77_encode_block(struct lzxhuff_compressor_context *cmp_ctx,
     746             :                                  struct lzxhuff_compressor_mem *cmp_mem,
     747             :                                  uint16_t *hash_table,
     748             :                                  uint16_t *prev_hash_table)
     749             : {
     750           0 :         uint16_t *intermediate = cmp_mem->intermediate;
     751           0 :         struct huffman_node *leaf_nodes = cmp_mem->leaf_nodes;
     752           0 :         uint16_t *symbol_values = cmp_mem->symbol_values;
     753             :         size_t i, j, intermediate_len;
     754           0 :         const uint8_t *data = cmp_ctx->input_bytes + cmp_ctx->input_pos;
     755           0 :         const uint8_t *prev_block = NULL;
     756           0 :         size_t remaining_size = cmp_ctx->input_size - cmp_ctx->input_pos;
     757           0 :         size_t block_end = MIN(65536, remaining_size);
     758             :         struct match match;
     759             :         int n_symbols;
     760             : 
     761           0 :         if (cmp_ctx->input_size < cmp_ctx->input_pos) {
     762           0 :                 return LZXPRESS_ERROR;
     763             :         }
     764             : 
     765           0 :         if (cmp_ctx->prev_block_pos != cmp_ctx->input_pos) {
     766           0 :                 prev_block = cmp_ctx->input_bytes + cmp_ctx->prev_block_pos;
     767           0 :         } else if (prev_hash_table != NULL) {
     768             :                 /* we've got confused! hash and block should go together */
     769           0 :                 return LZXPRESS_ERROR;
     770             :         }
     771             : 
     772             :         /*
     773             :          * leaf_nodes is used to count the symbols seen, for later Huffman
     774             :          * encoding.
     775             :          */
     776           0 :         for (i = 0; i < 512; i++) {
     777           0 :                 leaf_nodes[i] = (struct huffman_node) {
     778             :                         .symbol = i
     779             :                 };
     780             :         }
     781             : 
     782           0 :         j = 0;
     783             : 
     784           0 :         if (remaining_size < 41 || DEBUG_NO_LZ77_MATCHES) {
     785             :                 /*
     786             :                  * There is no point doing a hash table and looking for
     787             :                  * matches in this tiny block (remembering we are committed to
     788             :                  * using 32 bits, so there's a good chance we wouldn't even
     789             :                  * save a byte). The threshold of 41 matches Windows.
     790             :                  * If remaining_size < 3, we *can't* do the hash.
     791             :                  */
     792           0 :                 i = 0;
     793             :         } else {
     794             :                 /*
     795             :                  * We use 0xffff as the unset value for table, because it is
     796             :                  * not a valid match offset (and 0x0 is).
     797             :                  */
     798           0 :                 memset(hash_table, 0xff, sizeof(cmp_mem->hash_table1));
     799             : 
     800           0 :                 for (i = 0; i <= block_end - 3; i++) {
     801             :                         uint16_t code;
     802           0 :                         const uint8_t *here = data + i;
     803           0 :                         uint16_t h = three_byte_hash(here);
     804           0 :                         size_t max_len = MIN(remaining_size - i, MAX_MATCH_LENGTH);
     805           0 :                         match = lookup_match(hash_table,
     806             :                                              h,
     807             :                                              data,
     808             :                                              here,
     809             :                                              max_len);
     810             : 
     811           0 :                         if (match.there == NULL && prev_hash_table != NULL) {
     812             :                                 /*
     813             :                                  * If this is not the first block,
     814             :                                  * backreferences can look into the previous
     815             :                                  * block (but only as far as 65535 bytes, so
     816             :                                  * the end of this block cannot see the start
     817             :                                  * of the last one).
     818             :                                  */
     819           0 :                                 match = lookup_match(prev_hash_table,
     820             :                                                      h,
     821             :                                                      prev_block,
     822             :                                                      here,
     823             :                                                      remaining_size - i);
     824             :                         }
     825             : 
     826           0 :                         store_match(hash_table, h, i);
     827             : 
     828           0 :                         if (match.there == NULL) {
     829             :                                 /* add a literal and move on. */
     830           0 :                                 uint8_t c = data[i];
     831           0 :                                 leaf_nodes[c].count++;
     832           0 :                                 intermediate[j] = c;
     833           0 :                                 j++;
     834           0 :                                 continue;
     835             :                         }
     836             : 
     837             :                         /* a real match */
     838           0 :                         if (match.length <= 65538) {
     839           0 :                                 intermediate[j] = 0xffff;
     840           0 :                                 intermediate[j + 1] = match.length - 3;
     841           0 :                                 intermediate[j + 2] = here - match.there;
     842           0 :                                 j += 3;
     843             :                         } else {
     844           0 :                                 size_t m = match.length - 3;
     845           0 :                                 intermediate[j] = 0xfffe;
     846           0 :                                 intermediate[j + 1] = m & 0xffff;
     847           0 :                                 intermediate[j + 2] = m >> 16;
     848           0 :                                 intermediate[j + 3] = here - match.there;
     849           0 :                                 j += 4;
     850             :                         }
     851           0 :                         code = encode_match(match.length, here - match.there);
     852           0 :                         leaf_nodes[code].count++;
     853           0 :                         i += match.length - 1; /* `- 1` for the loop i++ */
     854             :                         /*
     855             :                          * A match can take us past the intended block length,
     856             :                          * extending the block. We don't need to do anything
     857             :                          * special for this case -- the loops will naturally
     858             :                          * do the right thing.
     859             :                          */
     860             :                 }
     861             :         }
     862             : 
     863             :         /*
     864             :          * There might be some bytes at the end.
     865             :          */
     866           0 :         for (; i < block_end; i++) {
     867           0 :                 leaf_nodes[data[i]].count++;
     868           0 :                 intermediate[j] = data[i];
     869           0 :                 j++;
     870             :         }
     871             : 
     872           0 :         if (i == remaining_size) {
     873             :                 /* add a trailing EOF marker (256) */
     874           0 :                 intermediate[j] = 0xffff;
     875           0 :                 intermediate[j + 1] = 0;
     876           0 :                 intermediate[j + 2] = 1;
     877           0 :                 j += 3;
     878           0 :                 leaf_nodes[256].count++;
     879             :         }
     880             : 
     881           0 :         intermediate_len = j;
     882             : 
     883           0 :         cmp_ctx->prev_block_pos = cmp_ctx->input_pos;
     884           0 :         cmp_ctx->input_pos += i;
     885             : 
     886             :         /* fill in the symbols table */
     887           0 :         n_symbols = generate_huffman_codes(leaf_nodes,
     888           0 :                                            cmp_mem->internal_nodes,
     889             :                                            symbol_values);
     890           0 :         if (n_symbols < 0) {
     891           0 :                 return n_symbols;
     892             :         }
     893             : 
     894           0 :         return intermediate_len;
     895             : }
     896             : 
     897             : 
     898             : 
     899           0 : static ssize_t write_huffman_table(uint16_t symbol_values[512],
     900             :                                    uint8_t *output,
     901             :                                    size_t available_size)
     902             : {
     903             :         size_t i;
     904             : 
     905           0 :         if (available_size < 256) {
     906           0 :                 return LZXPRESS_ERROR;
     907             :         }
     908             : 
     909           0 :         for (i = 0; i < 256; i++) {
     910           0 :                 uint8_t b = 0;
     911           0 :                 uint16_t even = symbol_values[i * 2];
     912           0 :                 uint16_t odd = symbol_values[i * 2 + 1];
     913           0 :                 if (even != 0) {
     914           0 :                         b = bitlen_nonzero_16(even);
     915             :                 }
     916           0 :                 if (odd != 0) {
     917           0 :                         b |= bitlen_nonzero_16(odd) << 4;
     918             :                 }
     919           0 :                 output[i] = b;
     920             :         }
     921           0 :         return i;
     922             : }
     923             : 
     924             : 
     925             : struct write_context {
     926             :         uint8_t *dest;
     927             :         size_t dest_len;
     928             :         size_t head;                 /* where lengths go */
     929             :         size_t next_code;            /* where symbol stream goes */
     930             :         size_t pending_next_code;    /* will be next_code */
     931             :         unsigned bit_len;
     932             :         uint32_t bits;
     933             : };
     934             : 
     935             : /*
     936             :  * Write out 16 bits, little-endian, for write_huffman_codes()
     937             :  *
     938             :  * As you'll notice, there's a bit to do.
     939             :  *
     940             :  * We are collecting up bits in a uint32_t, then when there are 16 of them we
     941             :  * write out a word into the stream, using a trio of offsets (wc->next_code,
     942             :  * wc->pending_next_code, and wc->head) which dance around ensuring that the
     943             :  * bitstream and the interspersed lengths are in the right places relative to
     944             :  * each other.
     945             :  */
     946             : 
     947           0 : static inline bool write_bits(struct write_context *wc,
     948             :                               uint16_t code, uint16_t length)
     949             : {
     950           0 :         wc->bits <<= length;
     951           0 :         wc->bits |= code;
     952           0 :         wc->bit_len += length;
     953           0 :         if (wc->bit_len > 16) {
     954           0 :                 uint32_t w = wc->bits >> (wc->bit_len - 16);
     955           0 :                 wc->bit_len -= 16;
     956           0 :                 if (wc->next_code + 2 > wc->dest_len ||
     957           0 :                     unlikely(wc->bit_len > 16)) {
     958           0 :                         return false;
     959             :                 }
     960           0 :                 wc->dest[wc->next_code] = w & 0xff;
     961           0 :                 wc->dest[wc->next_code + 1] = (w >> 8) & 0xff;
     962           0 :                 wc->next_code = wc->pending_next_code;
     963           0 :                 wc->pending_next_code = wc->head;
     964           0 :                 wc->head += 2;
     965             :         }
     966           0 :         return true;
     967             : }
     968             : 
     969             : 
     970           0 : static inline bool write_code(struct write_context *wc, uint16_t code)
     971             : {
     972           0 :         int code_bit_len = bitlen_nonzero_16(code);
     973           0 :         if (unlikely(code == 0)) {
     974           0 :                 return false;
     975             :         }
     976           0 :         code &= (1 << code_bit_len) - 1;
     977           0 :         return  write_bits(wc, code, code_bit_len);
     978             : }
     979             : 
     980           0 : static inline bool write_byte(struct write_context *wc, uint8_t byte)
     981             : {
     982           0 :         if (wc->head + 1 > wc->dest_len) {
     983           0 :                 return false;
     984             :         }
     985           0 :         wc->dest[wc->head] = byte;
     986           0 :         wc->head++;
     987           0 :         return true;
     988             : }
     989             : 
     990             : 
     991           0 : static inline bool write_long_len(struct write_context *wc, size_t len)
     992             : {
     993           0 :         if (len < 65535) {
     994           0 :                 if (wc->head + 3 > wc->dest_len) {
     995           0 :                         return false;
     996             :                 }
     997           0 :                 wc->dest[wc->head] = 255;
     998           0 :                 wc->dest[wc->head + 1] = len & 255;
     999           0 :                 wc->dest[wc->head + 2] = len >> 8;
    1000           0 :                 wc->head += 3;
    1001             :         } else {
    1002           0 :                 if (wc->head + 7 > wc->dest_len) {
    1003           0 :                         return false;
    1004             :                 }
    1005           0 :                 wc->dest[wc->head] = 255;
    1006           0 :                 wc->dest[wc->head + 1] = 0;
    1007           0 :                 wc->dest[wc->head + 2] = 0;
    1008           0 :                 wc->dest[wc->head + 3] = len & 255;
    1009           0 :                 wc->dest[wc->head + 4] = (len >> 8) & 255;
    1010           0 :                 wc->dest[wc->head + 5] = (len >> 16) & 255;
    1011           0 :                 wc->dest[wc->head + 6] = (len >> 24) & 255;
    1012           0 :                 wc->head += 7;
    1013             :         }
    1014           0 :         return true;
    1015             : }
    1016             : 
    1017           0 : static ssize_t write_compressed_bytes(uint16_t symbol_values[512],
    1018             :                                       uint16_t *intermediate,
    1019             :                                       size_t intermediate_len,
    1020             :                                       uint8_t *dest,
    1021             :                                       size_t dest_len)
    1022             : {
    1023             :         bool ok;
    1024             :         size_t i;
    1025             :         size_t end;
    1026           0 :         struct write_context wc = {
    1027             :                 .head = 4,
    1028             :                 .pending_next_code = 2,
    1029             :                 .dest = dest,
    1030             :                 .dest_len = dest_len
    1031             :         };
    1032           0 :         for (i = 0; i < intermediate_len; i++) {
    1033           0 :                 uint16_t c = intermediate[i];
    1034             :                 size_t len;
    1035             :                 uint16_t distance;
    1036           0 :                 uint16_t code_len = 0;
    1037           0 :                 uint16_t code_dist = 0;
    1038           0 :                 if (c < 256) {
    1039           0 :                         ok = write_code(&wc, symbol_values[c]);
    1040           0 :                         if (!ok) {
    1041           0 :                                 return LZXPRESS_ERROR;
    1042             :                         }
    1043           0 :                         continue;
    1044             :                 }
    1045             : 
    1046           0 :                 if (c == 0xfffe) {
    1047           0 :                         if (i > intermediate_len - 4) {
    1048           0 :                                 return LZXPRESS_ERROR;
    1049             :                         }
    1050             : 
    1051           0 :                         len = intermediate[i + 1];
    1052           0 :                         len |= intermediate[i + 2] << 16U;
    1053           0 :                         distance = intermediate[i + 3];
    1054           0 :                         i += 3;
    1055           0 :                 } else if (c == 0xffff) {
    1056           0 :                         if (i > intermediate_len - 3) {
    1057           0 :                                 return LZXPRESS_ERROR;
    1058             :                         }
    1059           0 :                         len = intermediate[i + 1];
    1060           0 :                         distance = intermediate[i + 2];
    1061           0 :                         i += 2;
    1062             :                 } else {
    1063           0 :                         return LZXPRESS_ERROR;
    1064             :                 }
    1065           0 :                 if (unlikely(distance == 0)) {
    1066           0 :                         return LZXPRESS_ERROR;
    1067             :                 }
    1068             :                 /* len has already had 3 subtracted */
    1069           0 :                 if (len >= 15) {
    1070             :                         /*
    1071             :                          * We are going to need to write extra length
    1072             :                          * bytes into the stream, but we don't do it
    1073             :                          * now, we do it after the code has been
    1074             :                          * written (and before the distance bits).
    1075             :                          */
    1076           0 :                         code_len = 15;
    1077             :                 } else {
    1078           0 :                         code_len = len;
    1079             :                 }
    1080           0 :                 code_dist = bitlen_nonzero_16(distance);
    1081           0 :                 c = 256 | (code_dist << 4) | code_len;
    1082           0 :                 if (c > 511) {
    1083           0 :                         return LZXPRESS_ERROR;
    1084             :                 }
    1085             : 
    1086           0 :                 ok = write_code(&wc, symbol_values[c]);
    1087           0 :                 if (!ok) {
    1088           0 :                         return LZXPRESS_ERROR;
    1089             :                 }
    1090             : 
    1091           0 :                 if (code_len == 15) {
    1092           0 :                         if (len >= 270) {
    1093           0 :                                 ok = write_long_len(&wc, len);
    1094             :                         } else {
    1095           0 :                                 ok = write_byte(&wc, len - 15);
    1096             :                         }
    1097           0 :                         if (! ok) {
    1098           0 :                                 return LZXPRESS_ERROR;
    1099             :                         }
    1100             :                 }
    1101           0 :                 if (code_dist != 0) {
    1102           0 :                         uint16_t dist_bits = distance - (1 << code_dist);
    1103           0 :                         ok = write_bits(&wc, dist_bits, code_dist);
    1104           0 :                         if (!ok) {
    1105           0 :                                 return LZXPRESS_ERROR;
    1106             :                         }
    1107             :                 }
    1108             :         }
    1109             :         /*
    1110             :          * There are some intricacies around flushing the bits and returning
    1111             :          * the length.
    1112             :          *
    1113             :          * If the returned length is not exactly right and there is another
    1114             :          * block, that block will read its huffman table from the wrong place,
    1115             :          * and have all the symbol codes out by a multiple of 4.
    1116             :          */
    1117           0 :         end = wc.head;
    1118           0 :         if (wc.bit_len == 0) {
    1119           0 :                 end -= 2;
    1120             :         }
    1121           0 :         ok = write_bits(&wc, 0, 16 - wc.bit_len);
    1122           0 :         if (!ok) {
    1123           0 :                 return LZXPRESS_ERROR;
    1124             :         }
    1125           0 :         for (i = 0; i < 2; i++) {
    1126             :                 /*
    1127             :                  * Flush out the bits with zeroes. It doesn't matter if we do
    1128             :                  * a round too many, as we have buffer space, and have already
    1129             :                  * determined the returned length (end).
    1130             :                  */
    1131           0 :                 ok = write_bits(&wc, 0, 16);
    1132           0 :                 if (!ok) {
    1133           0 :                         return LZXPRESS_ERROR;
    1134             :                 }
    1135             :         }
    1136           0 :         return end;
    1137             : }
    1138             : 
    1139             : 
    1140           0 : static ssize_t lzx_huffman_compress_block(struct lzxhuff_compressor_context *cmp_ctx,
    1141             :                                           struct lzxhuff_compressor_mem *cmp_mem,
    1142             :                                           size_t block_no)
    1143             : {
    1144             :         ssize_t intermediate_size;
    1145           0 :         uint16_t *hash_table = NULL;
    1146           0 :         uint16_t *back_window_hash_table = NULL;
    1147             :         ssize_t bytes_written;
    1148             : 
    1149           0 :         if (cmp_ctx->available_size - cmp_ctx->output_pos < 260) {
    1150             :                 /* huffman block + 4 bytes */
    1151           0 :                 return LZXPRESS_ERROR;
    1152             :         }
    1153             : 
    1154             :         /*
    1155             :          * For LZ77 compression, we keep a hash table for the previous block,
    1156             :          * via alternation after the first block.
    1157             :          *
    1158             :          * LZ77 writes into the intermediate buffer in the cmp_mem context.
    1159             :          */
    1160           0 :         if (block_no == 0) {
    1161           0 :                 hash_table = cmp_mem->hash_table1;
    1162           0 :                 back_window_hash_table = NULL;
    1163           0 :         } else if (block_no & 1) {
    1164           0 :                 hash_table = cmp_mem->hash_table2;
    1165           0 :                 back_window_hash_table = cmp_mem->hash_table1;
    1166             :         } else {
    1167           0 :                 hash_table = cmp_mem->hash_table1;
    1168           0 :                 back_window_hash_table = cmp_mem->hash_table2;
    1169             :         }
    1170             : 
    1171           0 :         intermediate_size = lz77_encode_block(cmp_ctx,
    1172             :                                               cmp_mem,
    1173             :                                               hash_table,
    1174             :                                               back_window_hash_table);
    1175             : 
    1176           0 :         if (intermediate_size < 0) {
    1177           0 :                 return intermediate_size;
    1178             :         }
    1179             : 
    1180             :         /*
    1181             :          * Write the 256 byte Huffman table, based on the counts gained in
    1182             :          * LZ77 phase.
    1183             :          */
    1184           0 :         bytes_written = write_huffman_table(
    1185           0 :                 cmp_mem->symbol_values,
    1186           0 :                 cmp_ctx->output + cmp_ctx->output_pos,
    1187           0 :                 cmp_ctx->available_size - cmp_ctx->output_pos);
    1188             : 
    1189           0 :         if (bytes_written != 256) {
    1190           0 :                 return LZXPRESS_ERROR;
    1191             :         }
    1192           0 :         cmp_ctx->output_pos += 256;
    1193             : 
    1194             :         /*
    1195             :          * Write the compressed bytes using the LZ77 matches and Huffman codes
    1196             :          * worked out in the previous steps.
    1197             :          */
    1198           0 :         bytes_written = write_compressed_bytes(
    1199           0 :                 cmp_mem->symbol_values,
    1200           0 :                 cmp_mem->intermediate,
    1201             :                 intermediate_size,
    1202           0 :                 cmp_ctx->output + cmp_ctx->output_pos,
    1203           0 :                 cmp_ctx->available_size - cmp_ctx->output_pos);
    1204             : 
    1205           0 :         if (bytes_written < 0) {
    1206           0 :                 return bytes_written;
    1207             :         }
    1208             : 
    1209           0 :         cmp_ctx->output_pos += bytes_written;
    1210           0 :         return bytes_written;
    1211             : }
    1212             : 
    1213             : 
    1214             : /*
    1215             :  * lzxpress_huffman_compress_talloc()
    1216             :  *
    1217             :  * This is the convenience function that allocates the compressor context and
    1218             :  * output memory for you. The return value is the number of bytes written to
    1219             :  * the location indicated by the output pointer.
    1220             :  *
    1221             :  * The maximum input_size is effectively around 227MB due to the need to guess
    1222             :  * an upper bound on the output size that hits an internal limitation in
    1223             :  * talloc.
    1224             :  *
    1225             :  * @param mem_ctx      TALLOC_CTX parent for the compressed buffer.
    1226             :  * @param input_bytes  memory to be compressed.
    1227             :  * @param input_size   length of the input buffer.
    1228             :  * @param output       destination pointer for the compressed data.
    1229             :  *
    1230             :  * @return the number of bytes written or -1 on error.
    1231             :  */
    1232             : 
    1233           0 : ssize_t lzxpress_huffman_compress_talloc(TALLOC_CTX *mem_ctx,
    1234             :                                          const uint8_t *input_bytes,
    1235             :                                          size_t input_size,
    1236             :                                          uint8_t **output)
    1237             : {
    1238           0 :         struct lzxhuff_compressor_mem *cmp = NULL;
    1239             :         /*
    1240             :          * In the worst case, the output size should be about the same as the
    1241             :          * input size, plus the 256 byte header per 64k block. We aim for
    1242             :          * ample, but within the order of magnitude.
    1243             :          */
    1244           0 :         size_t alloc_size = input_size + (input_size / 8) + 270;
    1245             :         ssize_t output_size;
    1246             : 
    1247           0 :         *output = talloc_array(mem_ctx, uint8_t, alloc_size);
    1248           0 :         if (*output == NULL) {
    1249           0 :                 return LZXPRESS_ERROR;
    1250             :         }
    1251             : 
    1252           0 :         cmp = talloc(mem_ctx, struct lzxhuff_compressor_mem);
    1253           0 :         if (cmp == NULL) {
    1254           0 :                 TALLOC_FREE(*output);
    1255           0 :                 return LZXPRESS_ERROR;
    1256             :         }
    1257             : 
    1258           0 :         output_size = lzxpress_huffman_compress(cmp,
    1259             :                                                 input_bytes,
    1260             :                                                 input_size,
    1261             :                                                 *output,
    1262             :                                                 alloc_size);
    1263             : 
    1264           0 :         talloc_free(cmp);
    1265             : 
    1266           0 :         if (output_size < 0) {
    1267           0 :                 TALLOC_FREE(*output);
    1268           0 :                 return LZXPRESS_ERROR;
    1269             :         }
    1270             : 
    1271           0 :         *output = talloc_realloc(mem_ctx, *output, uint8_t, output_size);
    1272           0 :         if (*output == NULL) {
    1273           0 :                 return LZXPRESS_ERROR;
    1274             :         }
    1275             : 
    1276           0 :         return output_size;
    1277             : }
    1278             : 
    1279             : /*
    1280             :  * lzxpress_huffman_compress()
    1281             :  *
    1282             :  * This is the inconvenience function, slightly faster and fiddlier than
    1283             :  * lzxpress_huffman_compress_talloc().
    1284             :  *
    1285             :  * To use this, you need to have allocated (but not initialised) a `struct
    1286             :  * lzxhuff_compressor_context`, and an output buffer. If the buffer is not big
    1287             :  * enough (per `output_size`), you'll get a negative return value, otherwise
    1288             :  * the number of bytes actually consumed, which will always be at least 260.
    1289             :  *
    1290             :  * The `struct lzxhuff_compressor_context` is reusable -- it is basically a
    1291             :  * collection of uninitialised memory buffers. The total size is less than
    1292             :  * 150k, so stack allocation is plausible.
    1293             :  *
    1294             :  * input_size and available_size are limited to the minimum of UINT32_MAX and
    1295             :  * SSIZE_MAX. On 64 bit machines that will be UINT32_MAX, or 4GB.
    1296             :  *
    1297             :  * @param cmp_mem         a struct lzxhuff_compressor_mem.
    1298             :  * @param input_bytes     memory to be compressed.
    1299             :  * @param input_size      length of the input buffer.
    1300             :  * @param output          destination for the compressed data.
    1301             :  * @param available_size  allocated output bytes.
    1302             :  *
    1303             :  * @return the number of bytes written or -1 on error.
    1304             :  */
    1305           0 : ssize_t lzxpress_huffman_compress(struct lzxhuff_compressor_mem *cmp_mem,
    1306             :                                   const uint8_t *input_bytes,
    1307             :                                   size_t input_size,
    1308             :                                   uint8_t *output,
    1309             :                                   size_t available_size)
    1310             : {
    1311           0 :         size_t i = 0;
    1312           0 :         struct lzxhuff_compressor_context cmp_ctx = {
    1313             :                 .input_bytes = input_bytes,
    1314             :                 .input_size = input_size,
    1315             :                 .input_pos = 0,
    1316             :                 .prev_block_pos = 0,
    1317             :                 .output = output,
    1318             :                 .available_size = available_size,
    1319             :                 .output_pos = 0
    1320             :         };
    1321             : 
    1322           0 :         if (input_size == 0) {
    1323             :                 /*
    1324             :                  * We can't deal with this for a number of reasons (e.g. it
    1325             :                  * breaks the Huffman tree), and the output will be infinitely
    1326             :                  * bigger than the input. The caller needs to go and think
    1327             :                  * about what they're trying to do here.
    1328             :                  */
    1329           0 :                 return LZXPRESS_ERROR;
    1330             :         }
    1331             : 
    1332           0 :         if (input_size > SSIZE_MAX ||
    1333           0 :             input_size > UINT32_MAX ||
    1334           0 :             available_size > SSIZE_MAX ||
    1335           0 :             available_size > UINT32_MAX ||
    1336             :             available_size == 0) {
    1337             :                 /*
    1338             :                  * We use negative ssize_t to return errors, which is limiting
    1339             :                  * on 32 bit machines; otherwise we adhere to Microsoft's 4GB
    1340             :                  * limit.
    1341             :                  *
    1342             :                  * lzxpress_huffman_compress_talloc() will not get this far,
    1343             :                  * having already have failed on talloc's 256 MB limit.
    1344             :                  */
    1345           0 :                 return LZXPRESS_ERROR;
    1346             :         }
    1347             : 
    1348           0 :         if (cmp_mem == NULL ||
    1349           0 :             output == NULL ||
    1350             :             input_bytes == NULL) {
    1351           0 :                 return LZXPRESS_ERROR;
    1352             :         }
    1353             : 
    1354           0 :         while (cmp_ctx.input_pos < cmp_ctx.input_size) {
    1355             :                 ssize_t ret;
    1356           0 :                 ret = lzx_huffman_compress_block(&cmp_ctx,
    1357             :                                                  cmp_mem,
    1358             :                                                  i);
    1359           0 :                 if (ret < 0) {
    1360           0 :                         return ret;
    1361             :                 }
    1362           0 :                 i++;
    1363             :         }
    1364             : 
    1365           0 :         return cmp_ctx.output_pos;
    1366             : }
    1367             : 
    1368           0 : static void debug_tree_codes(struct bitstream *input)
    1369             : {
    1370             :         /*
    1371             :          */
    1372           0 :         size_t head = 0;
    1373           0 :         size_t tail = 2;
    1374           0 :         size_t ffff_count = 0;
    1375             :         struct q {
    1376             :                 uint16_t tree_code;
    1377             :                 uint16_t code_code;
    1378             :         };
    1379             :         struct q queue[65536];
    1380             :         char bits[17];
    1381           0 :         uint16_t *t = input->table;
    1382           0 :         queue[0].tree_code = 1;
    1383           0 :         queue[0].code_code = 2;
    1384           0 :         queue[1].tree_code = 2;
    1385           0 :         queue[1].code_code = 3;
    1386           0 :         while (head < tail) {
    1387           0 :                 struct q q = queue[head];
    1388           0 :                 uint16_t x = t[q.tree_code];
    1389           0 :                 if (x != 0xffff) {
    1390             :                         int k;
    1391           0 :                         uint16_t j = q.code_code;
    1392           0 :                         size_t offset = bitlen_nonzero_16(j) - 1;
    1393           0 :                         if (unlikely(j == 0)) {
    1394           0 :                                 DBG("BROKEN code is 0!\n");
    1395           0 :                                 return;
    1396             :                         }
    1397             : 
    1398           0 :                         for (k = 0; k <= offset; k++) {
    1399           0 :                                 bool b = (j >> (offset - k)) & 1;
    1400           0 :                                 bits[k] = b ? '1' : '0';
    1401             :                         }
    1402           0 :                         bits[k] = 0;
    1403           0 :                         DBG("%03x   %s\n", x & 511, bits);
    1404           0 :                         head++;
    1405           0 :                         continue;
    1406             :                 }
    1407           0 :                 ffff_count++;
    1408           0 :                 queue[tail].tree_code = q.tree_code * 2 + 1;
    1409           0 :                 queue[tail].code_code = q.code_code * 2;
    1410           0 :                 tail++;
    1411           0 :                 queue[tail].tree_code = q.tree_code * 2 + 1 + 1;
    1412           0 :                 queue[tail].code_code = q.code_code * 2 + 1;
    1413           0 :                 tail++;
    1414           0 :                 head++;
    1415             :         }
    1416           0 :         DBG("0xffff count: %zu\n", ffff_count);
    1417             : }
    1418             : 
    1419             : /**
    1420             :  * Determines the sort order of one prefix_code_symbol relative to another
    1421             :  */
    1422           0 : static int compare_uint16(const uint16_t *a, const uint16_t *b)
    1423             : {
    1424           0 :         if (*a < *b) {
    1425           0 :                 return -1;
    1426             :         }
    1427           0 :         if (*a > *b) {
    1428           0 :                 return 1;
    1429             :         }
    1430           0 :         return 0;
    1431             : }
    1432             : 
    1433             : 
    1434           0 : static bool fill_decomp_table(struct bitstream *input)
    1435             : {
    1436             :         /*
    1437             :          * There are 512 symbols, each encoded in 4 bits, which indicates
    1438             :          * their depth in the Huffman tree. The even numbers get the lower
    1439             :          * nibble of each byte, so that the byte hex values look backwards
    1440             :          * (i.e. 0xab encodes b then a). These are allocated Huffman codes in
    1441             :          * order of appearance, per depth.
    1442             :          *
    1443             :          * For example, if the first two bytes were:
    1444             :          *
    1445             :          * 0x23 0x53
    1446             :          *
    1447             :          * the first four codes have the lengths 3, 2, 3, 5.
    1448             :          * Let's call them A, B, C, D.
    1449             :          *
    1450             :          * Suppose there is no other codeword with length 1 (which is
    1451             :          * necessarily true in this example) or 2, but there might be others
    1452             :          * of length 3 or 4. Then we can say this about the codes:
    1453             :          *
    1454             :          *        _ --*--_
    1455             :          *      /          \
    1456             :          *     0           1
    1457             :          *    / \         / \
    1458             :          *   0   1       0   1
    1459             :          *  B    |\     / \  |\
    1460             :          *       0 1   0   1 0 1
    1461             :          *       A C   |\ /| | |\
    1462             :          *
    1463             :          * pos bits  code
    1464             :          * A    3    010
    1465             :          * B    2    00
    1466             :          * C    3    011
    1467             :          * D    5    1????
    1468             :          *
    1469             :          * B has the shortest code, so takes the leftmost branch, 00. That
    1470             :          * ends the branch -- nothing else can start with 00. There are no
    1471             :          * more 2s, so we look at the 3s, starting as far left as possible. So
    1472             :          * A takes 010 and C takes 011. That means everything else has to
    1473             :          * start with 1xx. We don't know how many codewords of length 3 or 4
    1474             :          * there are; if there are none, D would end up with 10000, the
    1475             :          * leftmost available code of length 5. If the compressor is any good,
    1476             :          * there should be no unused leaf nodes left dangling at the end.
    1477             :          *
    1478             :          * (this is "Canonical Huffman Coding").
    1479             :          *
    1480             :          *
    1481             :          * But what symbols do these codes actually stand for?
    1482             :          * --------------------------------------------------
    1483             :          *
    1484             :          * Good question. The first 256 codes stand for the corresponding
    1485             :          * literal bytes. The codes from 256 to 511 stand for LZ77 matches,
    1486             :          * which have a distance and a length, encoded in a strange way that
    1487             :          * isn't entirely the purview of this function.
    1488             :          *
    1489             :          * What does the value 0 mean?
    1490             :          * ---------------------------
    1491             :          *
    1492             :          * The code does not occur. For example, if the next byte in the
    1493             :          * example above was 0x07, that would give the byte 0x04 a 7-long
    1494             :          * code, and no code to the 0x05 byte, which means we there is no way
    1495             :          * we going to see a 5 in the decoded stream.
    1496             :          *
    1497             :          * Isn't LZ77 + Huffman what zip/gzip/zlib do?
    1498             :          * -------------------------------------------
    1499             :          *
    1500             :          * Yes, DEFLATE is LZ77 + Huffman, but the details are quite different.
    1501             :          */
    1502             :         uint16_t symbols[512];
    1503             :         uint16_t sort_mem[512];
    1504             :         size_t i, n_symbols;
    1505             :         ssize_t code;
    1506             :         uint16_t len, prev_len;
    1507           0 :         const uint8_t *table_bytes = input->bytes + input->byte_pos;
    1508             : 
    1509           0 :         if (input->byte_pos + 260 > input->byte_size) {
    1510           0 :                 return false;
    1511             :         }
    1512             : 
    1513           0 :         n_symbols = 0;
    1514           0 :         for (i = 0; i < 256; i++) {
    1515           0 :                 uint16_t even = table_bytes[i] & 15;
    1516           0 :                 uint16_t odd = table_bytes[i] >> 4;
    1517           0 :                 if (even != 0) {
    1518           0 :                         symbols[n_symbols] = (even << 9) + i * 2;
    1519           0 :                         n_symbols++;
    1520             :                 }
    1521           0 :                 if (odd != 0) {
    1522           0 :                         symbols[n_symbols] = (odd << 9) + i * 2 + 1;
    1523           0 :                         n_symbols++;
    1524             :                 }
    1525             :         }
    1526           0 :         input->byte_pos += 256;
    1527           0 :         if (n_symbols == 0) {
    1528           0 :                 return false;
    1529             :         }
    1530             : 
    1531           0 :         stable_sort(symbols, sort_mem, n_symbols, sizeof(uint16_t),
    1532             :                     (samba_compare_fn_t)compare_uint16);
    1533             : 
    1534             :         /*
    1535             :          * we're using an implicit binary tree, as you'd see in a heap.
    1536             :          * table[0] = unused
    1537             :          * table[1] = '0'
    1538             :          * table[2] = '1'
    1539             :          * table[3] = '00'     <-- '00' and '01' are children of '0'
    1540             :          * table[4] = '01'     <-- '0' is [0], children are [0 * 2 + {1,2}]
    1541             :          * table[5] = '10'
    1542             :          * table[6] = '11'
    1543             :          * table[7] = '000'
    1544             :          * table[8] = '001'
    1545             :          * table[9] = '010'
    1546             :          * table[10]= '011'
    1547             :          * table[11]= '100
    1548             :          *'
    1549             :          * table[1 << n - 1] = '0' * n
    1550             :          * table[1 << n - 1 + x] = n-bit wide x (left padded with '0')
    1551             :          * table[1 << n - 2] = '1' * (n - 1)
    1552             :          *
    1553             :          * table[i]->left =  table[i*2 + 1]
    1554             :          * table[i]->right = table[i*2 + 2]
    1555             :          * table[0xffff] = unused (16 '0's, max len is 15)
    1556             :          *
    1557             :          * therefore e.g. table[70] = table[64     - 1 + 7]
    1558             :          *                          = table[1 << 6 - 1 + 7]
    1559             :          *                          = '000111' (binary 7, widened to 6 bits)
    1560             :          *
    1561             :          *   and if '000111' is a code,
    1562             :          *   '00011', '0001', '000', '00', '0' are unavailable prefixes.
    1563             :          *       34      16      7     3    1  are their indices
    1564             :          *   and (i - 1) >> 1 is the path back from 70 through these.
    1565             :          *
    1566             :          * the lookup is
    1567             :          *
    1568             :          * 1 start with i = 0
    1569             :          * 2 extract a symbol bit (i = (i << 1) + bit + 1)
    1570             :          * 3 is table[i] == 0xffff?
    1571             :          * 4  yes -- goto 2
    1572             :          * 4  table[i] & 511 is the symbol, stop
    1573             :          *
    1574             :          * and the construction (here) is sort of the reverse.
    1575             :          *
    1576             :          * Most of this table is free space that can never be reached, and
    1577             :          * most of the activity is at the beginning (since all codes start
    1578             :          * there, and by design the shortest codes are the most common).
    1579             :          */
    1580           0 :         for (i = 0; i < 32; i++) {
    1581             :                 /* prefill the table head */
    1582           0 :                 input->table[i] = 0xffff;
    1583             :         }
    1584           0 :         code = -1;
    1585           0 :         prev_len = 0;
    1586           0 :         for (i = 0; i < n_symbols; i++) {
    1587           0 :                 uint16_t s = symbols[i];
    1588             :                 uint16_t prefix;
    1589           0 :                 len = (s >> 9) & 15;
    1590           0 :                 s &= 511;
    1591           0 :                 code++;
    1592           0 :                 while (len != prev_len) {
    1593           0 :                         code <<= 1;
    1594           0 :                         code++;
    1595           0 :                         prev_len++;
    1596             :                 }
    1597             : 
    1598           0 :                 if (code >= 65535) {
    1599           0 :                         return false;
    1600             :                 }
    1601           0 :                 input->table[code] = s;
    1602           0 :                 for(prefix = (code - 1) >> 1;
    1603           0 :                     prefix > 31;
    1604           0 :                     prefix = (prefix - 1) >> 1) {
    1605           0 :                         input->table[prefix] = 0xffff;
    1606             :                 }
    1607             :         }
    1608           0 :         if (CHECK_DEBUGLVL(10)) {
    1609           0 :                 debug_tree_codes(input);
    1610             :         }
    1611             : 
    1612             :         /*
    1613             :          * check that the last code encodes 11111..., with right number of
    1614             :          * ones, pointing to the right symbol -- otherwise we have a dangling
    1615             :          * uninitialised symbol.
    1616             :          */
    1617           0 :         if (code != (1 << (len + 1)) - 2) {
    1618           0 :                 return false;
    1619             :         }
    1620           0 :         return true;
    1621             : }
    1622             : 
    1623             : 
    1624             : #define CHECK_READ_32(dest)                                       \
    1625             :         do {                                                      \
    1626             :                 if (input->byte_pos + 4 > input->byte_size) {     \
    1627             :                         return LZXPRESS_ERROR;                     \
    1628             :                 }                                                  \
    1629             :                 dest = PULL_LE_U32(input->bytes, input->byte_pos); \
    1630             :                 input->byte_pos += 4;                                   \
    1631             :         } while (0)
    1632             : 
    1633             : #define CHECK_READ_16(dest)                                       \
    1634             :         do {                                                      \
    1635             :                 if (input->byte_pos + 2 > input->byte_size) {     \
    1636             :                         return LZXPRESS_ERROR;                     \
    1637             :                 }                                                  \
    1638             :                 dest = PULL_LE_U16(input->bytes, input->byte_pos); \
    1639             :                 input->byte_pos += 2;                                   \
    1640             :         } while (0)
    1641             : 
    1642             : #define CHECK_READ_8(dest) \
    1643             :         do {                                                            \
    1644             :                 if (input->byte_pos >= input->byte_size) {             \
    1645             :                         return LZXPRESS_ERROR;                          \
    1646             :                 }                                                       \
    1647             :                 dest = PULL_LE_U8(input->bytes, input->byte_pos); \
    1648             :                 input->byte_pos++;                                   \
    1649             :         } while(0)
    1650             : 
    1651             : 
    1652           0 : static inline ssize_t pull_bits(struct bitstream *input)
    1653             : {
    1654           0 :         if (input->byte_pos + 1 < input->byte_size) {
    1655             :                 uint16_t tmp;
    1656           0 :                 CHECK_READ_16(tmp);
    1657           0 :                 input->remaining_bits += 16;
    1658           0 :                 input->bits <<= 16;
    1659           0 :                 input->bits |= tmp;
    1660           0 :         } else if (input->byte_pos < input->byte_size) {
    1661             :                 uint8_t tmp;
    1662           0 :                 CHECK_READ_8(tmp);
    1663           0 :                 input->remaining_bits += 8;
    1664           0 :                 input->bits <<= 8;
    1665           0 :                 input->bits |= tmp;
    1666             :         } else {
    1667           0 :                 return LZXPRESS_ERROR;
    1668             :         }
    1669           0 :         return 0;
    1670             : }
    1671             : 
    1672             : 
    1673             : /*
    1674             :  * Decompress a block. The actual decompressed size is returned (or -1 on
    1675             :  * error). The putative block length is 64k (or shorter, if the message ends
    1676             :  * first), but a match can run over the end, extending the block. That's why
    1677             :  * we need the overall output size as well as the block size. A match encoded
    1678             :  * in this block can point back to previous blocks, but not before the
    1679             :  * beginning of the message, so we also need the previously decoded size.
    1680             :  *
    1681             :  * The compressed block will have 256 bytes for the Huffman table, and at
    1682             :  * least 4 bytes of (possibly padded) encoded values.
    1683             :  */
    1684           0 : static ssize_t lzx_huffman_decompress_block(struct bitstream *input,
    1685             :                                             uint8_t *output,
    1686             :                                             size_t block_size,
    1687             :                                             size_t output_size,
    1688             :                                             size_t previous_size)
    1689             : {
    1690           0 :         size_t output_pos = 0;
    1691             :         uint16_t symbol;
    1692             :         size_t index;
    1693           0 :         uint16_t distance_bits_wanted = 0;
    1694           0 :         size_t distance = 0;
    1695           0 :         size_t length = 0;
    1696             :         bool ok;
    1697             :         uint32_t tmp;
    1698           0 :         bool seen_eof_marker = false;
    1699             : 
    1700           0 :         ok = fill_decomp_table(input);
    1701           0 :         if (! ok) {
    1702           0 :                 return LZXPRESS_ERROR;
    1703             :         }
    1704           0 :         if (CHECK_DEBUGLVL(10) || DEBUG_HUFFMAN_TREE) {
    1705           0 :                 debug_huffman_tree_from_table(input->table);
    1706             :         }
    1707             :         /*
    1708             :          * Always read 32 bits at the start, even if we don't need them.
    1709             :          */
    1710           0 :         CHECK_READ_16(tmp);
    1711           0 :         CHECK_READ_16(input->bits);
    1712           0 :         input->bits |= tmp << 16;
    1713           0 :         input->remaining_bits = 32;
    1714             : 
    1715             :         /*
    1716             :          * This loop iterates over individual *bits*. These are read from
    1717             :          * little-endian 16 bit words, most significant bit first.
    1718             :          *
    1719             :          * At points in the bitstream, the following are possible:
    1720             :          *
    1721             :          * # the source word is empty and needs to be refilled from the input
    1722             :          *    stream.
    1723             :          * # an incomplete codeword is being extended.
    1724             :          * # a codeword is resolved, either as a literal or a match.
    1725             :          * # a literal is written.
    1726             :          * # a match is collecting distance bits.
    1727             :          * # the output stream is copied, as specified by a match.
    1728             :          * # input bytes are read for match lengths.
    1729             :          *
    1730             :          * Note that we *don't* specifically check for the EOF marker (symbol
    1731             :          * 256) in this loop, because the a precondition for stopping for the
    1732             :          * EOF marker is that the output buffer is full (otherwise, you
    1733             :          * wouldn't know which 256 is EOF, rather than an actual symbol), and
    1734             :          * we *always* want to stop when the buffer is full. So we work out if
    1735             :          * there is an EOF in in another loop after we stop writing.
    1736             :          */
    1737             : 
    1738           0 :         index = 0;
    1739           0 :         while (output_pos < block_size) {
    1740             :                 uint16_t b;
    1741           0 :                 if (input->remaining_bits == 16) {
    1742           0 :                         ssize_t ret = pull_bits(input);
    1743           0 :                         if (ret) {
    1744           0 :                                 return ret;
    1745             :                         }
    1746             :                 }
    1747           0 :                 input->remaining_bits--;
    1748             : 
    1749           0 :                 b = (input->bits >> input->remaining_bits) & 1;
    1750           0 :                 if (length == 0) {
    1751             :                         /* not in a match; pulling a codeword */
    1752           0 :                         index <<= 1;
    1753           0 :                         index += b + 1;
    1754           0 :                         if (input->table[index] == 0xffff) {
    1755             :                                 /* incomplete codeword, the common case */
    1756           0 :                                 continue;
    1757             :                         }
    1758             :                         /* found the symbol, reset the code string */
    1759           0 :                         symbol = input->table[index] & 511;
    1760           0 :                         index = 0;
    1761           0 :                         if (symbol < 256) {
    1762             :                                 /* a literal, the easy case */
    1763           0 :                                 output[output_pos] = symbol;
    1764           0 :                                 output_pos++;
    1765           0 :                                 continue;
    1766             :                         }
    1767             : 
    1768             :                         /* the beginning of a match */
    1769           0 :                         distance_bits_wanted = (symbol >> 4) & 15;
    1770           0 :                         distance = 1 << distance_bits_wanted;
    1771           0 :                         length = symbol & 15;
    1772           0 :                         if (length == 15) {
    1773           0 :                                 CHECK_READ_8(tmp);
    1774           0 :                                 length += tmp;
    1775           0 :                                 if (length == 255 + 15) {
    1776             :                                         /*
    1777             :                                          * note, we discard (don't add) the
    1778             :                                          * length so far.
    1779             :                                          */
    1780           0 :                                         CHECK_READ_16(length);
    1781           0 :                                         if (length == 0) {
    1782           0 :                                                 CHECK_READ_32(length);
    1783             :                                         }
    1784             :                                 }
    1785             :                         }
    1786           0 :                         length += 3;
    1787             :                 } else {
    1788             :                         /* we are pulling extra distance bits */
    1789           0 :                         distance_bits_wanted--;
    1790           0 :                         distance |= b << distance_bits_wanted;
    1791             :                 }
    1792             : 
    1793           0 :                 if (distance_bits_wanted == 0) {
    1794             :                         /*
    1795             :                          * We have a complete match, and it is time to do the
    1796             :                          * copy (byte by byte, because the ranges can overlap,
    1797             :                          * and we might need to copy bytes we just copied in).
    1798             :                          *
    1799             :                          * It is possible that this match will extend beyond
    1800             :                          * the end of the expected block. That's fine, so long
    1801             :                          * as it doesn't extend past the total output size.
    1802             :                          */
    1803             :                         size_t i;
    1804           0 :                         size_t end = output_pos + length;
    1805           0 :                         uint8_t *here = output + output_pos;
    1806           0 :                         uint8_t *there = here - distance;
    1807           0 :                         if (end > output_size ||
    1808           0 :                             previous_size + output_pos < distance ||
    1809           0 :                             unlikely(end < output_pos || there > here)) {
    1810           0 :                                 return LZXPRESS_ERROR;
    1811             :                         }
    1812           0 :                         for (i = 0; i < length; i++) {
    1813           0 :                                 here[i] = there[i];
    1814             :                         }
    1815           0 :                         output_pos += length;
    1816           0 :                         distance = 0;
    1817           0 :                         length = 0;
    1818             :                 }
    1819             :         }
    1820             : 
    1821           0 :         if (length != 0 || index != 0) {
    1822             :                 /* it seems like we've hit an early end, mid-code */
    1823           0 :                 return LZXPRESS_ERROR;
    1824             :         }
    1825             : 
    1826           0 :         if (input->byte_pos + 256 < input->byte_size) {
    1827             :                 /*
    1828             :                  * This block is over, but it clearly isn't the last block, so
    1829             :                  * we don't want to look for the EOF.
    1830             :                  */
    1831           0 :                 return output_pos;
    1832             :         }
    1833             :         /*
    1834             :          * We won't write any more, but we try to read some more to make sure
    1835             :          * we're finishing in a good place. That means we want to see a 256
    1836             :          * symbol and then some number of zeroes, possibly zero, but as many
    1837             :          * as 32.
    1838             :          *
    1839             :          * In this we are perhaps a bit stricter than Windows, which
    1840             :          * apparently does not insist on the EOF marker, nor on a lack of
    1841             :          * trailing bytes.
    1842             :          */
    1843           0 :         while (true) {
    1844             :                 uint16_t b;
    1845           0 :                 if (input->remaining_bits == 16) {
    1846             :                         ssize_t ret;
    1847           0 :                         if (input->byte_pos == input->byte_size) {
    1848             :                                 /* FIN */
    1849           0 :                                 break;
    1850             :                         }
    1851           0 :                         ret = pull_bits(input);
    1852           0 :                         if (ret) {
    1853           0 :                                 return ret;
    1854             :                         }
    1855             :                 }
    1856           0 :                 input->remaining_bits--;
    1857           0 :                 b = (input->bits >> input->remaining_bits) & 1;
    1858           0 :                 if (seen_eof_marker) {
    1859             :                         /*
    1860             :                          * we have read an EOF symbols. Now we just want to
    1861             :                          * see zeroes.
    1862             :                          */
    1863           0 :                         if (b != 0) {
    1864           0 :                                 return LZXPRESS_ERROR;
    1865             :                         }
    1866           0 :                         continue;
    1867             :                 }
    1868             : 
    1869             :                 /* we're pulling in a symbol, which had better be 256 */
    1870           0 :                 index <<= 1;
    1871           0 :                 index += b + 1;
    1872           0 :                 if (input->table[index] == 0xffff) {
    1873           0 :                         continue;
    1874             :                 }
    1875             : 
    1876           0 :                 symbol = input->table[index] & 511;
    1877           0 :                 if (symbol != 256) {
    1878           0 :                         return LZXPRESS_ERROR;
    1879             :                 }
    1880           0 :                 seen_eof_marker = true;
    1881           0 :                 continue;
    1882             :         }
    1883             : 
    1884           0 :         if (! seen_eof_marker) {
    1885           0 :                 return LZXPRESS_ERROR;
    1886             :         }
    1887             : 
    1888           0 :         return output_pos;
    1889             : }
    1890             : 
    1891           0 : static ssize_t lzxpress_huffman_decompress_internal(struct bitstream *input,
    1892             :                                                     uint8_t *output,
    1893             :                                                     size_t output_size)
    1894             : {
    1895           0 :         size_t output_pos = 0;
    1896             : 
    1897           0 :         if (input->byte_size < 260) {
    1898           0 :                 return LZXPRESS_ERROR;
    1899             :         }
    1900             : 
    1901           0 :         while (input->byte_pos < input->byte_size) {
    1902             :                 ssize_t block_output_pos;
    1903             :                 ssize_t block_output_size;
    1904           0 :                 size_t remaining_output_size = output_size - output_pos;
    1905             : 
    1906           0 :                 block_output_size = MIN(65536, remaining_output_size);
    1907             : 
    1908           0 :                 block_output_pos = lzx_huffman_decompress_block(
    1909             :                         input,
    1910             :                         output + output_pos,
    1911             :                         block_output_size,
    1912             :                         remaining_output_size,
    1913             :                         output_pos);
    1914             : 
    1915           0 :                 if (block_output_pos < block_output_size) {
    1916           0 :                         return LZXPRESS_ERROR;
    1917             :                 }
    1918           0 :                 output_pos += block_output_pos;
    1919           0 :                 if (output_pos > output_size) {
    1920             :                         /* not expecting to get here. */
    1921           0 :                         return LZXPRESS_ERROR;
    1922             :                 }
    1923             :         }
    1924             : 
    1925           0 :         if (input->byte_pos != input->byte_size) {
    1926           0 :                 return LZXPRESS_ERROR;
    1927             :         }
    1928             : 
    1929           0 :         return output_pos;
    1930             : }
    1931             : 
    1932             : 
    1933             : /*
    1934             :  * lzxpress_huffman_decompress()
    1935             :  *
    1936             :  * output_size must be the expected length of the decompressed data.
    1937             :  * input_size and output_size are limited to the minimum of UINT32_MAX and
    1938             :  * SSIZE_MAX. On 64 bit machines that will be UINT32_MAX, or 4GB.
    1939             :  *
    1940             :  * @param input_bytes  memory to be decompressed.
    1941             :  * @param input_size   length of the compressed buffer.
    1942             :  * @param output       destination for the decompressed data.
    1943             :  * @param output_size  exact expected length of the decompressed data.
    1944             :  *
    1945             :  * @return the number of bytes written or -1 on error.
    1946             :  */
    1947             : 
    1948           0 : ssize_t lzxpress_huffman_decompress(const uint8_t *input_bytes,
    1949             :                                     size_t input_size,
    1950             :                                     uint8_t *output,
    1951             :                                     size_t output_size)
    1952             : {
    1953             :         uint16_t table[65536];
    1954           0 :         struct bitstream input = {
    1955             :                 .bytes = input_bytes,
    1956             :                 .byte_size = input_size,
    1957             :                 .byte_pos = 0,
    1958             :                 .bits = 0,
    1959             :                 .remaining_bits = 0,
    1960             :                 .table = table
    1961             :         };
    1962             : 
    1963           0 :         if (input_size > SSIZE_MAX ||
    1964           0 :             input_size > UINT32_MAX ||
    1965           0 :             output_size > SSIZE_MAX ||
    1966           0 :             output_size > UINT32_MAX ||
    1967           0 :             input_size == 0 ||
    1968           0 :             output_size == 0 ||
    1969           0 :             input_bytes == NULL ||
    1970             :             output == NULL) {
    1971             :                 /*
    1972             :                  * We use negative ssize_t to return errors, which is limiting
    1973             :                  * on 32 bit machines, and the 4GB limit exists on Windows.
    1974             :                  */
    1975           0 :                 return  LZXPRESS_ERROR;
    1976             :         }
    1977             : 
    1978           0 :         return lzxpress_huffman_decompress_internal(&input,
    1979             :                                                     output,
    1980             :                                                     output_size);
    1981             : }
    1982             : 
    1983             : 
    1984             : /**
    1985             :  * lzxpress_huffman_decompress_talloc()
    1986             :  *
    1987             :  * The caller must provide the exact size of the expected output.
    1988             :  *
    1989             :  * The input_size is limited to the minimum of UINT32_MAX and SSIZE_MAX, but
    1990             :  * output_size is limited to 256MB due to a limit in talloc. This effectively
    1991             :  * limits input_size too, as non-crafted compressed data will not exceed the
    1992             :  * decompressed size by very much.
    1993             :  *
    1994             :  * @param mem_ctx      TALLOC_CTX parent for the decompressed buffer.
    1995             :  * @param input_bytes  memory to be decompressed.
    1996             :  * @param input_size   length of the compressed buffer.
    1997             :  * @param output_size  expected decompressed size.
    1998             :  *
    1999             :  * @return a talloc'ed buffer exactly output_size in length, or NULL.
    2000             :  */
    2001             : 
    2002           0 : uint8_t *lzxpress_huffman_decompress_talloc(TALLOC_CTX *mem_ctx,
    2003             :                                             const uint8_t *input_bytes,
    2004             :                                             size_t input_size,
    2005             :                                             size_t output_size)
    2006             : {
    2007             :         ssize_t result;
    2008           0 :         uint8_t *output = NULL;
    2009           0 :         struct bitstream input = {
    2010             :                 .bytes = input_bytes,
    2011             :                 .byte_size = input_size
    2012             :         };
    2013             : 
    2014           0 :         output = talloc_array(mem_ctx, uint8_t, output_size);
    2015           0 :         if (output == NULL) {
    2016           0 :                 return NULL;
    2017             :         }
    2018             : 
    2019           0 :         input.table = talloc_array(mem_ctx, uint16_t, 65536);
    2020           0 :         if (input.table == NULL) {
    2021           0 :                 talloc_free(output);
    2022           0 :                 return NULL;
    2023             :         }
    2024           0 :         result = lzxpress_huffman_decompress_internal(&input,
    2025             :                                                       output,
    2026             :                                                       output_size);
    2027           0 :         talloc_free(input.table);
    2028             : 
    2029           0 :         if (result != output_size) {
    2030           0 :                 talloc_free(output);
    2031           0 :                 return NULL;
    2032             :         }
    2033           0 :         return output;
    2034             : }

Generated by: LCOV version 1.14