https://compilade.net/blog/ternary-packing

How to pack ternary numbers in 8-bit bytes with efficient SIMD-friendly unpacking
---------------------------------------------------------------------

Published: 2024-06-26

There are 3 possible values in a digit of a ternary number. These 3 values could actually be anything, for example:

-1 0 1
0 1 2

I've recently been nerd-sniped[1] into trying to pack the ternary weights of BitNet b1.58 into something close to the theoretical ideal of log(3) / log(2) bits[2] per ternary digit.

I'll call a "ternary digit" a "trit", just as a "binary digit" is called a "bit".

Block size

Since the goal is to allow fast parallel unpacking, blocks of trits can't be infinitely big. A small "block" size needs to be found, ideally one that is both efficient in information density and convenient on current hardware.

To find a good block size, we need a power of 3 for which the next power of 2 is very close.

| trits | 3^trits | bits | 2^bits | bits per trit |
|-------|---------|------|--------|---------------|
| 1     | 3       | 2    | 4      | 2             |
| 2     | 9       | 4    | 16     | 2             |
| 3     | 27      | 5    | 32     | 1.666...      |
| 4     | 81      | 7    | 128    | 1.75          |
| 5     | 243     | 8    | 256    | 1.6           |

It's very fortunate that 5 trits fit quite tightly into 8 bits, at 1.6 bits per trit. Compared to perfect packing at log(3) / log(2) ≈ 1.585 bits per trit, this is 99.06% efficient.

1.6 bits per trit

The basic idea of this packing scheme is simply to make a number out of the ternary digits.

```python
def pack_number(digits: list[int], base: int) -> int:
    number = 0
    for digit in digits:
        assert 0 <= digit < base
        number = number * base  # shift the previous digits up by one place
        number = number + digit
    return number
```

Packing trits into bytes works in a similar way.

Fast multiplication unpacking

While repeated modulo and division can be used to extract the digits of a number, the problem is that integer division and modulo are usually not supported in SIMD programming. The way around this is to view the numbers differently.
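For comparison, here is a minimal sketch of the conventional division-based approach. `unpack_number` is a hypothetical helper (not from the original post) that inverts `pack_number` with Python's built-in `divmod`. Note that it pulls out the least significant digit first, which is exactly the modulo-and-divide pattern that doesn't map well onto SIMD integer units.

```python
def unpack_number(number: int, base: int, n_digits: int) -> list[int]:
    """Hypothetical inverse of pack_number, using modulo and integer division.

    Digits come out least significant first, so they are written
    into the result list from the end.
    """
    digits = [0] * n_digits
    for i in reversed(range(n_digits)):
        number, digits[i] = divmod(number, base)
    return digits


# 127 is 11201 in base 3
print(unpack_number(127, 3, 5))  # prints [1, 1, 2, 0, 1]
```

Each step depends on a division by the base, which is cheap in scalar Python but has no direct equivalent among common SIMD integer instructions.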
Wouldn't it be nice if, instead of extracting the least significant digit with modulo, we could extract the most significant digit with a multiplication? Fixed-point numbers to the rescue!

For example, with the byte 0x7F (127):

0x7F. = 11201. in base 3 (same number)
divide by 243 → 0.11201 in base 3 (≈ 0.5226)
same number in base 16, rounded up → 0x0.86
multiply by 256 → 0x86 (134)

Tada! Now the most significant digit can easily be extracted from the top two bits of the 10-bit result of multiplying this 8-bit byte by 3. For 0x86, 0x86 × 3 = 402, and 402 >> 8 = 1 is the first digit of 11201; the low byte (146) is kept to extract the next digit the same way. This is much more convenient than modulo when unpacking with SIMD.

The only place where divisions are needed in this scheme is when packing trits into bytes. This assumes that packing is done less often than unpacking, which is very true in the context of LLM weights.

```python
# Take a list of values in -1, 0, 1 and pack them in bytes
def pack_trits(digits: list[int]) -> bytearray:
    assert len(digits) % 5 == 0  # padding isn't handled here
    n_bytes = len(digits) // 5
    packed = bytearray()
    for i in range(n_bytes):
        b = 0
        for j in range(5):
            digit = digits[5*i + j]
            digit = max(-1, min(digit, 1))  # clamp between -1 and 1
            digit += 1  # from -1, 0, 1 to 0, 1, 2
            b *= 3
            b += digit
        b = ((b * 256) + (243 - 1)) // 243
        packed.append(b)
    return packed
```

The interesting line is this one:

```python
b = ((b * 256) + (243 - 1)) // 243
```

It does what the example above shows (divide by 243, round up, multiply by 256), but the multiplication is done first because these are integer operations. A ceiling division is necessary here to cancel the off-by-one error caused by truncation when extracting digits later.
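To see why the ceiling division matters, here is a small exhaustive check over all 243 possible 5-trit values, comparing ceiling and floor rounding. The helper names (`to_fixed_point`, `top_digits`, `base3_digits`, `failures`) are made up for this sketch; the extraction loop uses the same multiply-by-3, take-the-top-bits idea described above.

```python
def to_fixed_point(n: int, round_up: bool) -> int:
    """Approximate n / 243 as an 8-bit fixed-point fraction (scaled by 256)."""
    if round_up:
        return (n * 256 + (243 - 1)) // 243  # ceiling division, as in pack_trits
    return (n * 256) // 243  # floor division, for comparison


def top_digits(q: int) -> list[int]:
    """Extract 5 base-3 digits, most significant first, by multiplying by 3."""
    digits = []
    for _ in range(5):
        q *= 3
        digits.append(q >> 8)  # top two bits of the 10-bit product
        q &= 0xFF  # keep the remaining fraction for the next digit
    return digits


def base3_digits(n: int) -> list[int]:
    """Reference base-3 digits of n, most significant first (uses division)."""
    return [(n // 3**k) % 3 for k in reversed(range(5))]


def failures(round_up: bool) -> list[int]:
    """Values in [0, 243) whose digits do not survive the round trip."""
    return [n for n in range(243)
            if top_digits(to_fixed_point(n, round_up)) != base3_digits(n)]


print("floor:", failures(round_up=False)[:5])  # prints floor: [1, 2, 3, 4, 5]
print("ceil:", failures(round_up=True))  # prints ceil: []
```

With floor division, the fixed-point value sits slightly below n / 243, so truncation can knock a digit down by one (n = 1 decodes as all zeros). Rounding up keeps the error in [0, 1/256), and since that is smaller than 1/243, the repeated multiplications by 3 never push the error into the next digit.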
To unpack without using the modulo operator:

```python
def unpack_trits(packed: bytes) -> list[int]:
    trits: list[int] = []
    for byte in packed:
        b = byte
        for _ in range(5):
            b = b * 3
            trit = b >> 8  # top two bits of the 10-bit product
            trits.append(trit - 1)  # 0, 1, 2 => -1, 0, 1
            b = b & 0xFF  # keep the remaining fraction
    return trits
```

To convince myself that this works, I wrote a C program checking that this really is lossless:

```c
#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    char s1[6] = {0};
    char s2[6] = {0};
    for (uint8_t i = 0; i < 243; ++i) {
        uint8_t n = i;
        // Get the number representation in base 3
        // by repeatedly extracting the least significant digit with modulo
        for (int j = 5; j-- > 0;) {
            s1[j] = (n % 3) + '0';
            n /= 3;
        }
        // Turn that number into a fixed-point number smaller than 1
        uint8_t q = (((uint16_t) i) * 256 + (243 - 1)) / 243;
        // This extracts the most significant digit first
        for (int j = 0; j < 5; ++j) {
            uint16_t m = q * 3;
            s2[j] = (m >> 8) + '0';
            q = m & 0xFF;
        }
        printf("%s, %s: %s\n", s1, s2,
               strcmp(s1, s2) == 0 ? "\033[1;32mPASS\033[0m" : "\033[1;31mFAIL\033[0m");
    }
    return 0;
}
```

Compile and run with:

```shell
$ gcc ternary-packing.c -o ternary-packing
$ ./ternary-packing
```

I'm getting PASS for each of the 243 ternary numbers that fit in 8 bits.

This is the technique used in the upcoming 1.625 bpw quantization type in llama.cpp for BitNet b1.58 (pull request: https://github.com/ggerganov/llama.cpp/pull/8151), which includes SIMD implementations for both AVX2 and ARM NEON.

---------------------------------------------------------------------

1. Obviously referring to https://xkcd.com/356/, but the initial motivation actually came from this review comment I posted on the initial BitNet b1.58 pull request for llama.cpp.
2. log(3) / log(2) is also known as 1.584962500721156.

Copyright (c) 2023-2024 by Compilade (CC BY-SA 4.0). Code snippets dedicated to the Public Domain (CC0 1.0)