https://github.com/ggerganov/llama.cpp/pull/1773

ggerganov / llama.cpp

llama : add grammar-based sampling #1773 (Open)

ejones wants to merge 19 commits into ggerganov:master from ejones:grammar
+969 -1 · Conversation 59 · Commits 19 · Checks 24 · Files changed 14

@ejones (Collaborator) commented Jun 9, 2023 (edited)

EDITED after updates

Inspired by #1397 and grantslatton's CFG work, this adds an API that takes a serialized context-free grammar to guide and constrain sampling. It also adds a sample Backus-Naur form (BNF)-like syntax in main for specifying a grammar for generations.
## Testing (M2 Max, 30B)

### Chess

```
% ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'A good game:\n\n' --grammar-file grammars/chess.gbnf
main: build = 674 (e550234)
main: seed  = 1688014137
llama.cpp: loading model from /Users/evan/llama-models/30B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size = 0.13 MB
llama_model_load_internal: mem required = 19756.66 MB (+ 3124.00 MB per state)
llama_init_from_file: kv self size = 780.00 MB

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 32, n_keep = 0

main: grammar:
root ::= [1] [.] [ ] move [ ] move [] root_4
move ::= move_5 move_9
root_2 ::= [1-9] root_3 [.] [ ] move [ ] move []
root_3 ::= [0-9] |
root_4 ::= root_2 root_4 | root_2
move_5 ::= pawn | nonpawn | castle
pawn ::= pawn_14 [a-h] [1-8] pawn_16
nonpawn ::= [NBKQR] nonpawn_10 nonpawn_11 nonpawn_12 [a-h] [1-8]
castle ::= [O] [-] [O] castle_17
move_9 ::= [+#] |
nonpawn_10 ::= [a-h] |
nonpawn_11 ::= [1-8] |
nonpawn_12 ::= [x] |
pawn_13 ::= [a-h] [x]
pawn_14 ::= pawn_13 |
pawn_15 ::= [=] [NBKQR]
pawn_16 ::= pawn_15 |
castle_17 ::= [-] [O] |

A good game:

1. e4 e5 2. Nf3 Nc6 3. Bb5 a6 4. Ba4 Nf6

llama_print_timings: load time = 1144.33 ms
llama_print_timings: sample time = 35.87 ms / 32 runs ( 1.12 ms per token)
llama_print_timings: prompt eval time = 1126.34 ms / 7 tokens ( 160.91 ms per token)
llama_print_timings: eval time = 5214.99 ms / 31 runs ( 168.23 ms per token)
llama_print_timings: total time = 6398.45 ms
```

### "Chess" without grammar

```
% ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'A good game:\n\n'
main: build = 645 (fd0eb66)
main: seed  = 1686286016
...
A good game:

Sir Thomas Gresham, when he was building his famous Exchange at London, had the following dialogue with a mason, whose name was Richard B

llama_print_timings: load time = 1185.47 ms
llama_print_timings: sample time = 21.57 ms / 32 runs ( 0.67 ms per token)
llama_print_timings: prompt eval time = 1167.67 ms / 7 tokens ( 166.81 ms per token)
llama_print_timings: eval time = 4977.97 ms / 31 runs ( 160.58 ms per token)
llama_print_timings: total time = 6188.21 ms
```

### Arithmetic

```
% ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'Some arithmetic practice:\n\n' \
  --grammar 'root  ::= (expr "=" ws num "\n")+
expr  ::= term ([-+*/] term)*
term  ::= ident | num | "(" ws expr ")" ws
ident ::= [a-z] [a-z0-9_]* ws
num   ::= [0-9]+ ws
ws    ::= [ \t\n]*'
main: build = 674 (e550234)
main: seed  = 1688014196
...
main: grammar:
root ::= root_5
root_1 ::= expr [=] ws num []
expr ::= term expr_8
ws ::= ws_12
num ::= num_11 ws
root_5 ::= root_1 root_5 | root_1
term ::= ident | num | [(] ws expr [)] ws
expr_7 ::= [-+*/] term
expr_8 ::= expr_7 expr_8 |
ident ::= [a-z] ident_10 ws
ident_10 ::= [a-z0-9_] ident_10 |
num_11 ::= [0-9] num_11 | [0-9]
ws_12 ::= [ ] ws_12 |

Some arithmetic practice:

10 *a*1 +b*2 =640
10 *a*2 +b*3 =656

llama_print_timings: load time = 1165.00 ms
llama_print_timings: sample time = 41.11 ms / 32 runs ( 1.28 ms per token)
llama_print_timings: prompt eval time = 1147.76 ms / 7 tokens ( 163.97 ms per token)
llama_print_timings: eval time = 5113.92 ms / 31 runs ( 164.97 ms per token)
llama_print_timings: total time = 6323.27 ms
```

### Arithmetic - no grammar

```
% ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'Some arithmetic practice:\n\n'
main: build = 645 (fd0eb66)
main: seed  = 1686286388
...
Some arithmetic practice:

\begin{code}
package main

import (
    "fmt"
)

func main() {
    fmt.Println(

llama_print_timings: load time = 1171.65 ms
llama_print_timings: sample time = 21.37 ms / 32 runs ( 0.67 ms per token)
llama_print_timings: prompt eval time = 1153.88 ms / 7 tokens ( 164.84 ms per token)
llama_print_timings: eval time = 4991.68 ms / 31 runs ( 161.02 ms per token)
llama_print_timings: total time = 6187.91 ms
```

### JSON

```
% ./main -m $LLAMA_30B_Q4_0 -n 64 -p $'A bit about me:\n\n' --grammar-file grammars/json.gbnf
main: build = 674 (e550234)
main: seed  = 1688014289
...
main: grammar:
root ::= object
object ::= [{] ws object_11 [}]
value ::= object | array | string | number | boolean
array ::= [[] ws array_15 []]
string ::= ["] string_16 ["] ws
number ::= number_17 number_18 ws
boolean ::= boolean_19 ws
ws ::= [ ] ws |
object_8 ::= string [:] ws value object_10
object_9 ::= [,] ws string [:] ws value
object_10 ::= object_9 object_10 |
object_11 ::= object_8 |
array_12 ::= value array_14
array_13 ::= [,] ws value
array_14 ::= array_13 array_14 |
array_15 ::= array_12 |
string_16 ::= [ !#-[]-~] string_16 |
number_17 ::= [-] |
number_18 ::= [0-9] number_18 | [0-9]
boolean_19 ::= [t] [r] [u] [e] | [f] [a] [l] [s] [e]

A bit about me:

{ "fullName": "Ramon Rodriguez", "username": "ramon", "email": "ramon@mail.com", "phoneNumber": "+1234567890", "address": {

llama_print_timings: load time = 1273.70 ms
llama_print_timings: sample time = 82.93 ms / 64 runs ( 1.30 ms per token)
llama_print_timings: prompt eval time = 1256.36 ms / 8 tokens ( 157.04 ms per token)
llama_print_timings: eval time = 10432.05 ms / 63 runs ( 165.59 ms per token)
llama_print_timings: total time = 11795.36 ms
```

### "JSON" - no grammar

```
% ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'A bit about me:\n\n'
main: build = 645 (fd0eb66)
main: seed  = 1686286615
...

A bit about me:

A former teacher, now a full-time writer.
I am the author of two novels: _The Man in the Moon_ and _The Riddle

llama_print_timings: load time = 1291.32 ms
llama_print_timings: sample time = 21.48 ms / 32 runs ( 0.67 ms per token)
llama_print_timings: prompt eval time = 1274.63 ms / 8 tokens ( 159.33 ms per token)
llama_print_timings: eval time = 4990.01 ms / 31 runs ( 160.97 ms per token)
llama_print_timings: total time = 6306.01 ms
```

### Japanese

```
% ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'Building a website can be done in 10 simple steps (from the original Japanese):\n\n' --grammar-file grammars/japanese.gbnf
main: build = 674 (e550234)
main: seed  = 1688013430
...

main: grammar:
root ::= root_2 root_5
jp-char ::= hiragana | katakana | punctuation | cjk
root_2 ::= jp-char root_2 | jp-char
root_3 ::= [ ] root_4
root_4 ::= jp-char root_4 | jp-char
root_5 ::= root_3 root_5 |
hiragana ::= [-]
katakana ::= [-]
punctuation ::= [-]
cjk ::= [-]

Building a website can be done in 10 simple steps (from the original Japanese):

Yi , Mu De haHe nanoka Er , oKe samawoSi iChu shite San , oKe samanokoto

llama_print_timings: load time = 2957.19 ms
llama_print_timings: sample time = 42.67 ms / 32 runs ( 1.33 ms per token)
llama_print_timings: prompt eval time = 2941.56 ms / 21 tokens ( 140.07 ms per token)
llama_print_timings: eval time = 5384.28 ms / 31 runs ( 173.69 ms per token)
llama_print_timings: total time = 8387.61 ms
```

### Japanese - no grammar

```
% ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'Building a website can be done in 10 simple steps (from the original Japanese):\n\n'
main: build = 674 (e550234)
main: seed  = 1688013483
...
Building a website can be done in 10 simple steps (from the original Japanese):

1. Determine your goal for your site. 2. Make a plan. 3. Select the domain name. 4. Choose web

llama_print_timings: load time = 2955.05 ms
llama_print_timings: sample time = 22.96 ms / 32 runs ( 0.72 ms per token)
llama_print_timings: prompt eval time = 2937.10 ms / 21 tokens ( 139.86 ms per token)
llama_print_timings: eval time = 5032.41 ms / 31 runs ( 162.34 ms per token)
llama_print_timings: total time = 8013.71 ms
```

## Approach

### Grammar API

The llama API accepts a data structure representing a context-free grammar over 32-bit code points:

```c
// grammar element type
enum llama_gretype {
    // end of rule definition
    LLAMA_GRETYPE_END            = 0,

    // start of alternate definition for rule
    LLAMA_GRETYPE_ALT            = 1,

    // non-terminal element: reference to rule
    LLAMA_GRETYPE_RULE_REF       = 2,

    // terminal element: character (code point)
    LLAMA_GRETYPE_CHAR           = 3,

    // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
    // be an inclusive range ([a-z])
    LLAMA_GRETYPE_CHAR_RNG_UPPER = 4,

    // modifies a preceding LLAMA_GRETYPE_CHAR or
    // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
    LLAMA_GRETYPE_CHAR_ALT       = 5,
};

typedef struct llama_grammar_element {
    enum llama_gretype type;
    uint32_t           value; // Unicode code point or rule ID
} llama_grammar_element;

LLAMA_API struct llama_grammar * llama_grammar_init(
        const llama_grammar_element ** rules,
        size_t                         n_rules,
        size_t                         start_rule_index);
```

### Sampling

The grammar sampling code models a nondeterministic pushdown automaton, maintaining N stacks for the possible parse states. Sampling a token is done in two steps: a sampling API that filters candidates to those that match one of the parse stacks (llama_sample_grammar), and adding the chosen token to the grammar (llama_grammar_accept_token).

### Examples

Adds --grammar and --grammar-file arguments to main, taking a simple extended BNF to constrain generations. The parser for this format is implemented in examples/grammar-parser.{h,cpp}:

```
// ... Supports character
// ranges, grouping, and repetition operators. As an example, a grammar for
// arithmetic might look like:
//
// root  ::= expr
// expr  ::= term ([-+*/] term)*
// term  ::= num | "(" space expr ")" space
// num   ::= [0-9]+ space
// space ::= [ \t\n]*
```

The root rule identifies the start of the grammar.

## ~~Caveats~~

* ~~the binary format makes the code harder to understand and more brittle~~
* ~~the grammar contemplates 16-bit chars but it's just being applied to the 8-bit UTF-8 chars in token strings currently~~
* ~~the 1-char lookahead sampling is probably biasing generations in a weird way; further investigation on quality of outputs is probably needed~~
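As a reading aid for the API above, here is a minimal sketch of how a caller might wire it into a sampling loop. The enum, struct, and llama_grammar_init declaration are taken from this PR; the llama_grammar_accept_token signature and the surrounding sampling-loop details are assumptions for illustration, not code from the PR.

```cpp
// Sketch: constructing a grammar for `root ::= [0-9]+` by hand and using it
// during sampling. The llama_gretype / llama_grammar_element /
// llama_grammar_init shapes come from this PR; the llama_grammar_accept_token
// signature below is assumed.
#include "llama.h"

#include <vector>

static llama_token sample_digit_token(llama_context * ctx, llama_grammar * grammar) {
    const int     n_vocab = llama_n_vocab(ctx);
    const float * logits  = llama_get_logits(ctx);

    std::vector<llama_token_data> cand;
    cand.reserve(n_vocab);
    for (llama_token id = 0; id < n_vocab; ++id) {
        cand.push_back({ id, logits[id], 0.0f });
    }
    llama_token_data_array cand_arr = { cand.data(), cand.size(), false };

    // step 1: filter candidates down to tokens that can extend the grammar
    llama_sample_grammar(ctx, &cand_arr, grammar);

    // step 2: pick a token (greedy here for brevity) ...
    const llama_token tok = llama_sample_token_greedy(ctx, &cand_arr);

    // ... and advance the grammar's parse stacks with it (assumed signature)
    llama_grammar_accept_token(ctx, grammar, tok);
    return tok;
}

int main() {
    // root ::= [0-9] root | [0-9]   (i.e. one or more digits)
    const std::vector<llama_grammar_element> root = {
        { LLAMA_GRETYPE_CHAR,           '0' },
        { LLAMA_GRETYPE_CHAR_RNG_UPPER, '9' },
        { LLAMA_GRETYPE_RULE_REF,        0  },
        { LLAMA_GRETYPE_ALT,             0  },
        { LLAMA_GRETYPE_CHAR,           '0' },
        { LLAMA_GRETYPE_CHAR_RNG_UPPER, '9' },
        { LLAMA_GRETYPE_END,             0  },
    };
    const llama_grammar_element * rules[] = { root.data() };

    llama_grammar * grammar = llama_grammar_init(rules, /*n_rules=*/1, /*start_rule_index=*/0);

    // ... load a model, evaluate the prompt, then call
    // sample_digit_token(ctx, grammar) inside the generation loop ...
    (void) grammar;
    return 0;
}
```

In the PR's main example, these same two calls sit alongside the existing samplers: llama_sample_grammar narrows the candidate list, and llama_grammar_accept_token advances the grammar once a token has been chosen.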
ejones added a commit: llama, main : constrain sampling to grammar (fd0eb66)

ggerganov added the "high priority" (Very important issue) label on Jun 9, 2023

@howard0su (Collaborator) commented Jun 10, 2023

Suggest taking a file as the grammar parameter and adding several examples, like what we did for prompts (in the ./prompts folder).

@tobi (Sponsor, Collaborator) commented Jun 10, 2023 (edited)

Incredibly useful contribution. It's really amazing how much this simplifies many use cases. I agree that it would be better if the grammar came from a file.

Two snags I hit while trying this out:

* it crashes with --prompt-cache
* any empty lines in the grammar cause a crash

Some additional thoughts:

* Would love to have the grammars support empty lines and comments
* I wonder if the grammar could be compiled into a tensor of state transitions and run on the GPU
* I wonder if there is an optimization where, if the next token is already known from the grammar, we could skip the inference and just add it? In many types of grammars, like JSON or HTML, that could really speed up generation
* I think it's worth allowing full tokens to be referenced from the grammar. Maybe something like @" token" or @13432 (ID of token)

@slaren (Collaborator) commented Jun 11, 2023

Very nice! I am wondering what the rationale is for not including the parser in the llama.cpp API. Without it, most downstream users will be forced to manually make a copy of the parser in their code to support the feature, which is not great.

Also, for usability, I think it would be a good idea to keep a copy of the binary grammar in llama_grammar, rather than asking users to keep the provided copy alive. The overhead would be minimal, and it would simplify the code of downstream users.

ejones added 6 commits on June 11, 2023: allow loading grammar from file (834d423), fix whitespace errors (9e77f42), handle & print parser errors (674bb08), add comments to grammar syntax and allow newlines where unambiguous (98a9587), Merge remote-tracking branch 'refs/remotes/upstream/master' into grammar (56904ca), add missing include (3e78f00)

@ejones (Collaborator, Author) commented Jun 12, 2023

Thanks all! Just added support for grammar files (with examples) and updated the grammar syntax to add shell-style comments and allow empty lines between rules, as well as newlines inside parenthesized groups.

> it crashes with --prompt-cache

I wonder if that was #1699?
If so, it should be fixed now.

> I wonder if the grammar could be compiled into a tensor of state transitions and run on the GPU

Sounds cool; I don't know enough about GPU programming to comment on that myself. The grammar participates in the sampling layer, and I'm not sure if that leverages the GPU currently.

> I wonder if there is an optimization where, if the next token is already known from the grammar, we could skip the inference and just add it?

This is definitely possible. That said, AFAIK the token would still need to be evaluated, and that seems to be the bottleneck. Maybe the optimization comes in being able to batch eval strings of such tokens?

> I think it's worth allowing full tokens to be referenced from the grammar

Neat idea. Would that be more of an optimization, or a way to reference tokens that can't be expressed textually?

> what is the rationale for not including the parser in the llama.cpp API

Honestly, I was trying to reduce the changes to llama.cpp itself. Agree it would be more convenient in the API.

> I think it would be a good idea to keep a copy of the binary grammar

Makes sense. I left that out of this round of changes - if it's desired to have the grammar parser in the llama API, this may naturally fit with that change.

@bullno1 (Contributor) commented Jun 12, 2023

First, this is amazing work.

This makes me wonder whether the entire sampling API should be pulled into something like llama_samplers instead. External samplers can evolve independently of the core API. The existing functions can be kept for compatibility. AFAIK, the only thing we need is to expose the RNG, and even then the existence of that inside a state/context is debatable. The context window is already managed by user code, so why not sampling?

This reminds me a lot of https://lmql.ai/. There is also https://github.com/1rgs/jsonformer, where the input is a JSON schema, which is not always easy to express in BNF.

> AFAIK the token would still need to be evaluated

Would it though? We just immediately add it to the context. It is done manually in user code now.

> Maybe the optimization comes in being able to batch eval strings of such tokens?

AFAIK, that's the case. The initial prompt and the user input are submitted in a large batch. The inference loop just feeds the single chosen token back until EOS.

> The grammar participates in the sampling layer, and I'm not sure if that leverages the GPU currently.

The current sampling is CPU.

@Green-Sky (Collaborator) commented Jun 12, 2023

> This makes me wonder whether the entire sampling API should be pulled into something like llama_samplers instead.

One of the discussion points for adding more LLM-generic tooling back into the ggml repo was moving the sampler there, but AFAIK nothing has happened yet :)

@ejones (Collaborator, Author) commented Jun 12, 2023

> There is also https://github.com/1rgs/jsonformer where the input is a json schema

Was planning to tackle this next. I've got it more or less working locally in a branch off of this, at least with the examples in jsonformer's README. It uses a Python script to generate a JSON BNF that conforms to the schema.
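Following up on the JSON-schema idea above: the conversion script referenced here is Python and lives in a separate branch, but as a rough illustration of the approach, a schema with a fixed set of typed properties can be lowered to a grammar string in this PR's BNF-like syntax. Everything below (the property names, the helper function, and the exact escaping) is a made-up sketch, not the script from that branch.

```cpp
// Toy illustration (not the script referenced above): emit a grammar in the
// BNF-like syntax from this PR that only accepts a JSON object with a fixed
// set of properties. Property names and types here are invented for the example.
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

static std::string schema_to_grammar(const std::vector<std::pair<std::string, std::string>> & props) {
    std::string root = "root ::= \"{\" ws ";
    for (size_t i = 0; i < props.size(); ++i) {
        if (i > 0) root += "\",\" ws ";
        // emit a literal key like "\"name\":" followed by the value rule
        root += "\"\\\"" + props[i].first + "\\\":\" ws " + props[i].second + " ";
    }
    root += "\"}\" ws\n";

    return root +
        "string ::= \"\\\"\" [a-zA-Z0-9 ]* \"\\\"\" ws\n"
        "number ::= [0-9]+ ws\n"
        "ws ::= [ \\t\\n]*\n";
}

int main() {
    // e.g. an object with a string "name" and a numeric "age"
    const std::string grammar = schema_to_grammar({ { "name", "string" }, { "age", "number" } });
    printf("%s", grammar.c_str());
    return 0;
}
```

The emitted string could then be handed to main via --grammar, or written to a file for --grammar-file.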
ejones added 4 commits on June 14, 2023: support alternates in root rule (421c6e1), fix bugs with empty token and EOS (b876d19), adjust JSON grammar (58ca9bc), remove swp file (414f251)

howard0su reviewed the changes on Jun 15, 2023 (llama.h):

```c
    LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
    LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);

    /// @details Apply constraints from grammar
    LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
```

@howard0su (Collaborator) commented Jun 15, 2023

Can we make llama_grammar a structure with two callbacks, so that other implementations of it can support a context-aware state machine instead?

@ejones (Collaborator, Author) commented Jun 16, 2023

Do you mean that the caller would provide the implementation of llama_grammar (via callbacks), from which the llama API determines which tokens are valid?

@howard0su (Collaborator) commented Jun 16, 2023

Yes, so the llama code will not assume the grammar implementation.

@ejones (Collaborator, Author) commented Jun 18, 2023

Yeah, I'm open to that idea, assuming the grammar interface itself generalizes well to other implementations. I kind of designed this with the specific implementation in mind, so that's not a guarantee.

@ggerganov (Owner) commented Jun 15, 2023

Great stuff! I'm still wrapping my head around this.

* Yes, this can become part of a llama.cpp or ggml sampling API, but I guess for now we can keep it as an example, see what the pros and cons are, and learn how to use it most efficiently
* What happens when the next N > 1 tokens are uniquely determined by the grammar? I guess we will sample them one by one, correct? What would it take to make it so that they are submitted to be processed as a batch? This would significantly speed up the inference in such cases
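A rough sketch of the batching idea raised in the comment above, for illustration only; this is not implemented in the PR. Only llama_sample_grammar's signature comes from the llama.h hunk shown earlier, and the llama_grammar_accept_token signature and loop details are assumptions.

```cpp
// Sketch (not part of this PR): when llama_sample_grammar leaves exactly one
// candidate, that token is forced by the grammar regardless of the logits, so
// forced tokens can be collected and submitted together instead of one
// decode step per token.
#include "llama.h"

#include <vector>

static std::vector<llama_token> collect_forced_tokens(llama_context * ctx, llama_grammar * grammar) {
    std::vector<llama_token> forced;

    for (;;) {
        const int     n_vocab = llama_n_vocab(ctx);
        const float * logits  = llama_get_logits(ctx);

        std::vector<llama_token_data> cand;
        cand.reserve(n_vocab);
        for (llama_token id = 0; id < n_vocab; ++id) {
            cand.push_back({ id, logits[id], 0.0f });
        }
        llama_token_data_array cand_arr = { cand.data(), cand.size(), false };

        // only the grammar filter matters here; the (stale) logits are unused
        llama_sample_grammar(ctx, &cand_arr, grammar);
        if (cand_arr.size != 1) {
            break; // more than one legal continuation: fall back to normal sampling
        }

        const llama_token tok = cand_arr.data[0].id;
        llama_grammar_accept_token(ctx, grammar, tok); // assumed signature
        forced.push_back(tok);

        if (forced.size() >= 64) {
            break; // safety cap for grammars that force long runs
        }
    }

    return forced;
}
```

The collected tokens could then be evaluated in one go with the llama_eval of this era of the API (something like llama_eval(ctx, forced.data(), (int) forced.size(), n_past, n_threads)) before handing control back to the regular sampling path.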
@ejones (Collaborator, Author) commented Jun 16, 2023

> Yes, this can become part of a llama.cpp or ggml sampling API, but I guess for now we can keep it as an example, see what the pros and cons are, and learn how to use it most efficiently

To clarify, this PR adds the core sampling functionality in llama.cpp, leaving the grammar parser out in examples. Should that all be moved to examples, or just left as is?

> What happens when the next N > 1 tokens are uniquely determined by the grammar? I guess we will sample them one by one, correct? What would it take to make it so that they are submitted to be processed as a batch? This would significantly speed up the inference in such cases

Yes, that's correct. I think that's doable; I can take a stab at it.

ejones mentioned this pull request on Jun 16, 2023: examples : generate JSON according to schema #1887 (Draft)

@SlyEcho (Sponsor, Collaborator) commented Jun 16, 2023 (edited)

> the grammar contemplates 16-bit chars but it's just being applied to the 8-bit UTF-8 chars in token strings currently

~~I don't understand this part. So it is converting to UTF-16?~~

~~Another option would be to use token values but it will be more limiting.~~

EDIT: I read through the code. The grammar doesn't care about the text encoding. It could work with any encoding, provided that the rules match the characters correctly. The parser doesn't understand UTF-8, so it will create rules that don't match as the user expects. For example, if I wanted to create a rule to match all hiragana characters, I should be able to write:

[ぁ-ゖ]

However, the parser doesn't see it as two characters separated by -; instead it sees:

[\xe3\x81\x81-\xe3\x82\x96]

But the correct rule should be something like this?

"\xe3" [\x81-\x82] [\x81-\x96]

SlyEcho reviewed the changes on Jun 16, 2023, leaving comments on llama.cpp and examples/grammar-parser.cpp (since resolved).

@ivanstepanovftw (Collaborator) commented Jun 16, 2023

Just don't use repeat penalties, to get the best grammar llama can produce.

@ggerganov (Owner) commented Jun 16, 2023

> To clarify, this PR adds the core sampling functionality in llama.cpp, leaving the grammar parser out in examples. Should that all be moved to examples, or just left as is?

It's fine the way it is.

@burke (Sponsor) commented Jun 16, 2023

FWIW, I'm adapting this code into an analogous feature for models running on torch. In my implementation, I'm doing grammar-enforcement logit masking on the GPU across the full token set before selecting candidates: https://github.com/Shopify/torch-grammar/blob/df23e354083c909c70120e256ed34036c93f6714/grammar_sampler.py#L232-L239. The same strategy would probably work here if anyone was super motivated to try it.

ggerganov reviewed the changes on Jun 17, 2023 (llama.h).
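As a follow-up to SlyEcho's UTF-8 comment above: until the grammar parser understands multi-byte characters, byte-level rules like "\xe3" [\x81-\x82] [\x81-\x96] have to be written by hand. A small throwaway helper (purely illustrative, not part of this PR) that prints the UTF-8 bytes of a code point makes that less error-prone.

```cpp
// Illustrative helper (not part of this PR): print the UTF-8 byte sequence of
// a code point so that byte-level grammar rules can be written by hand while
// the grammar parser is still byte-oriented.
#include <cstdint>
#include <cstdio>

static void print_utf8_bytes(uint32_t cp) {
    unsigned char buf[4];
    int n = 0;
    if (cp < 0x80) {
        buf[n++] = (unsigned char) cp;
    } else if (cp < 0x800) {
        buf[n++] = 0xC0 | (cp >> 6);
        buf[n++] = 0x80 | (cp & 0x3F);
    } else if (cp < 0x10000) {
        buf[n++] = 0xE0 | (cp >> 12);
        buf[n++] = 0x80 | ((cp >> 6) & 0x3F);
        buf[n++] = 0x80 | (cp & 0x3F);
    } else {
        buf[n++] = 0xF0 | (cp >> 18);
        buf[n++] = 0x80 | ((cp >> 12) & 0x3F);
        buf[n++] = 0x80 | ((cp >> 6) & 0x3F);
        buf[n++] = 0x80 | (cp & 0x3F);
    }
    printf("U+%04X =", (unsigned) cp);
    for (int i = 0; i < n; ++i) {
        printf(" \\x%02x", buf[i]);
    }
    printf("\n");
}

int main() {
    print_utf8_bytes(0x3041); // -> \xe3 \x81 \x81
    print_utf8_bytes(0x3096); // -> \xe3 \x82 \x96
    return 0;
}
```

Running it for U+3041 and U+3096 reproduces the byte endpoints quoted in the comment above.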
@ejones (Collaborator, Author) commented Jul 1, 2023

@mattpulver I reproduced the segfault, and it appears the problem in this case is left-recursive rules like query-expression -> non-join-query-expression -> query-expression or query-term -> non-join-query-term -> query-term. Since the grammar is processed top-down (and there's no special handling of left recursion), it infinitely recurses on expanding these initial references.

A workaround would be to adjust the rules to eliminate left recursion. E.g., maybe something along the lines of this?

```
query-expression ::= query-term (("UNION" | "EXCEPT") "ALL"? corresponding-spec? query-term)*
    | ... (other cases if needed)
```

We could probably at least detect left recursion and error out early. Longer term, we could look into an implementation that would support such grammars.

@ejones (Collaborator, Author) commented Jul 1, 2023

@tucnak at this point the grammars are defined over code points, with no specific handling or recognition of grapheme clusters. A particular grammar could recognize grapheme clusters (and in effect normalize them), and maybe that's what you're working towards with your solution?

@mudler (Contributor) commented Jul 2, 2023 (edited)

Great contribution @ejones, kudos! Can't wait to see this merged!

Just for reference, I've been trying this locally and it works like a charm! I went a bit ahead and tried it with the Golang bindings (branch at https://github.com/go-skynet/go-llama.cpp/tree/grammar) and LocalAI (https://github.com/go-skynet/LocalAI); the result is that it is now possible to emulate OpenAI functions and run their examples directly:

[screenshot: localai-functions-1]

What I wanted to do is give more data points and highlight that it also correctly chooses not to use any of the functions. From a first hands-on with it (and from a personal, empirical set of tests), it looks like the 1-char lookahead sampling is good enough if the model is "good" enough (I tested with WizardLM 7B), but it's very sensitive to the prompt:

[screenshot: functions-2]

This was referenced on Jul 2, 2023: feature: Chat completion functions (go-skynet/LocalAI#588, Closed); feature: constrained grammars (go-skynet/LocalAI#354, Closed); wip: add constrained grammar support (go-skynet/go-llama.cpp#124, Draft)

@mudler (Contributor) commented Jul 4, 2023 (edited)

JFYI, after playing with it a bit more, I've bumped into this while trying on ARM64+CUDA:

```
Jul 04 20:37:13 localhost local-ai[34380]: LLAMA_ASSERT: /usr/local/LocalAI/go-llama/llama.cpp/llama.cpp:2479: !new_stacks.empty()
Jul 04 20:37:13 localhost local-ai[34380]: SIGABRT: abort
```

This is me trying with the binding. Update: I can't reproduce it with llama.cpp. Ignore me, it must be something in the binding that is not correct.

ejones added a commit: Merge remote-tracking branch 'upstream/master' into grammar (38fbd40)

@ejones (Collaborator, Author) commented Jul 6, 2023

@mudler thanks for the feedback! Re: OpenAI functions, I also have a draft up at #1887 with a script to convert JSON schemas to grammars. Re: that assertion, it is triggered when the sampled token doesn't match the grammar at all.
If you do run into it on llama.cpp, let me know the inputs and I'm happy to look into it :).

@mudler (Contributor) commented Jul 6, 2023 (edited)

> @mudler thanks for the feedback! Re: OpenAI functions, I also have a draft up at #1887 with a script to convert JSON schemas to grammars. Re: that assertion, it is triggered when the sampled token doesn't match the grammar at all. If you do run into it on llama.cpp, let me know the inputs and I'm happy to look into it :).

Yes, awesome job! I've looked at that PR indeed, and slightly adapted it to Golang to generate grammars directly from the requests - my first attempts were simpler though, with a chain of "first choose an action -> then fill in the params" (to force it into some kind of 'reasoning' step).

However, what I'm seeing is quite weird - it doesn't happen when using llama.cpp directly, but only when using it with the Golang bindings (https://github.com/go-skynet/go-llama.cpp) on a particular setup I have (it only happens on ARM; my x86_64 machine runs just fine). The same grammar (basically the equivalent output from your sample in #1887) works fine on x86_64 but crashes on ARM+CUDA with the error above. I suspect something weird is going on in the toolchain package combination (gcc/nvcc) - I've tried to trace it back with gdb with no luck so far; it does seem that there is no match with the grammar rules (even though it does match on x86_64!).

I really appreciate your help! Thank you so much, but I don't want to bother you. It seems like running llama.cpp directly isn't causing any issues. I'll collect more data and see if I can figure out whether it's something that can be replicated or not.

mudler mentioned this pull request on Jul 6, 2023: feat: LocalAI functions (go-skynet/LocalAI#726, Merged)

KerfuffleV2 mentioned this pull request on Jul 7, 2023: Add optional llm_samplers sampler backend (rustformers/llm#359, Open)

@KerfuffleV2 (Collaborator) commented Jul 8, 2023

Perhaps consider using a more standard BNF syntax? For example: https://mdkrajnak.github.io/ebnftest/

The changes to existing grammars would be pretty minor:

```
root ::= jp-char+ (#'[ \t\n]' jp-char+)*
jp-char ::= hiragana | katakana | punctuation | cjk
hiragana ::= #'[a-[?]]'
katakana ::= #'[a-]'
punctuation ::= #'[, -]'
cjk ::= #'[Yi -]'
```

Looks like the only change really is quoting character classes like #'[blah-blah]' instead of using bare [blah-blah]. You don't have to support EBNF on your side; just using a similar format for the features you do support would make reuse of existing EBNF grammars a lot easier.

ejones added a commit: add unicode escapes (014fbfd)

@ejones (Collaborator, Author) commented Jul 12, 2023

@KerfuffleV2 when digging into EBNF (basically on Wikipedia), I concluded that there isn't a clear, single standard EBNF to follow. I opted for a format that would match typical modern regex syntax. For this, XML's EBNF was a good starting point, but using \x / \u / \U escapes in place of their #x. In the project you linked, the #'...' syntax seems to be an artifact of Clojure:

> This project uses the clojurescript port of instaparse which provides an excellent EBNF parser with one quirk: regexes must be quoted as clojure regex literals.
> For example the regex for any character A-Z, [A-Z], is entered as #'[A-Z]'.

I like the idea of reusing existing grammars. Perhaps it makes sense to implement alternate parsers or translators for the most popular formats? For example, it looks like ANTLR publishes a large set of example grammars, but I don't know that we'd want to tie our main format to ANTLR to benefit from that.

@ejones (Collaborator, Author) commented Jul 12, 2023

@SlyEcho gentle nudge - any comments on this?

ejones added a commit: Merge remote-tracking branch 'upstream/master' into grammar (b2e071d)

@wanicca commented Jul 12, 2023 (edited)

Hi, after reading the code I came up with a question. Say we have a rule that the model must generate either "apple" or "banana". The current implementation will keep all tokens with the prefix "a" as the filtered candidates, such as "ant", "ate", "apple". So if "ant" is sampled, it will be truncated to "a". Eventually, we may get a,p,p,l,e (five tokens) rather than apple (one token). Did I misunderstand something?

@SlyEcho (Sponsor, Collaborator) commented Jul 12, 2023

> @SlyEcho gentle nudge - any comments on this?

Sorry, I haven't found time to test it yet, but it looks pretty good already.

@ejones (Collaborator, Author) commented Jul 13, 2023

@wanicca

> Eventually, we may get a,p,p,l,e (five tokens) rather than apple (one token).

Correct. Although even with this 1-char implementation, in cases where one or more longer tokens are uniquely determined by the grammar (e.g., "a" implies "apple"), the batch sampling optimization suggested by @ggerganov may also fix this.

@SlyEcho (Sponsor, Collaborator) commented Jul 13, 2023

I suppose we could have backtracking sometime in the future?

@KerfuffleV2 (Collaborator) commented Jul 14, 2023

> In the project you linked, the #'...' syntax seems to be an artifact of Clojure:

Ah, fair enough. It seems like the Rust ebnf crate conforms to that EBNF tester, for what it's worth.

> Although even with this 1-char implementation, in cases where one or more longer tokens are uniquely determined by the grammar [...]

Isn't this more of a problem than it might seem at first glance, because models are trained with certain combinations of tokens? So causing the model to generate a,p,p,l,e piecemeal would likely cause the model to generate worse output, and the reason why wouldn't really be apparent because it still just looks like "apple" to the user.

@marclove commented Jul 14, 2023

> Isn't this more of a problem than it might seem at first glance, because models are trained with certain combinations of tokens?
> So causing the model to generate a,p,p,l,e piecemeal would likely cause the model to generate worse output, and the reason why wouldn't really be apparent because it still just looks like "apple" to the user.

If you're constraining your output based on a grammar, isn't violating the grammar the worst possible output? At least in the use cases I can think of, if I'm constraining on grammar, that's my first priority; qualitative measures separate from that are secondary.

Thanks for the work, @ejones. Looking forward to seeing this merged.

@KerfuffleV2 (Collaborator) commented Jul 15, 2023

> If you're constraining your output based on a grammar, isn't violating the grammar the worst possible output?

It depends. That's certainly true sometimes, but unless the cons are clearly documented/apparent, users might also use grammar sampling to direct responses in the direction they prefer without necessarily understanding the tradeoff.

> if I'm constraining on grammar, that's my first priority; qualitative measures separate from that are secondary.

I'd say it depends again; even if sticking to the grammar is paramount, you still have to end up with a result at a quality level that's worth using for your application.

---

Anyway, I want to be clear: I'm not opposing this pull or saying it shouldn't be merged or anything like that. My motivation for bringing up that point is in the interest of improving llama.cpp, either by considering/exploring other approaches that may not involve these tradeoffs or by ensuring they're documented when they exist. I'm not just being negative/contrary.

@ejones (Collaborator, Author) commented Jul 18, 2023

> causing the model to generate a,p,p,l,e piecemeal would likely cause the model to generate worse output and the reason why wouldn't really be apparent because it still just looks like "apple" to the user

Yes, I agree. I think the instructions and examples in the prompt can reduce the likelihood of this, but it is a problem. Maybe I should revisit the 1-char assumption.

ejones added 2 commits on July 18, 2023: add inverse char ranges (8d37755), only sample full tokens (no peeking or truncation) (c047e8a)

@ejones (Collaborator, Author) commented Jul 21, 2023

@wanicca @KerfuffleV2 @ggerganov the latest update restricts sampling to complete token matches, removing the 1-char peeking and token truncation. I noticed it's still possible for the model to inadvertently split up tokens, as it may select a prefix of another token where specific sequences are expected (e.g., selecting the "a" or "b" tokens at "apple" | "banana"). There's probably a further improvement to select the longest matching token in such cases, or something.

ggerganov added a commit: llama : minor style changes
(11315b1): "blindly applied in online editor - hopefully I didn't break something"

ggerganov approved these changes on Jul 21, 2023 (edited):

Great stuff! Looking forward to playing with this.

Merge it when you wish.

Reviewers: howard0su (left review comments), SlyEcho (left review comments), ggerganov (approved)
Assignees: ejones
Labels: enhancement (New feature or request), generation quality (Quality of model output), high priority (Very important issue)
Projects: ggml : roadmap (Status: In Progress)
Milestone: none
17 participants: ejones, howard0su, tobi, slaren, bullno1, Green-Sky, ggerganov, SlyEcho, ivanstepanovftw, burke, mattpulver, zakkor, tucnak, mudler, KerfuffleV2, wanicca, marclove