https://github.com/karpathy/llm.c/discussions/677

Let's reproduce GPT-2 (1.6B): one 8XH100 node, 24 hours, $672, in llm.c #677
karpathy started this conversation in General

karpathy · Jul 11, 2024 (Maintainer)

In this post we are reproducing GPT-2 in llm.c. This is "the GPT-2", the full, 1558M-parameter version that was introduced in OpenAI's blog post Better Language Models and their Implications on February 14, 2019. llm.c does so directly in C/CUDA (a total of ~5,000 lines of code), without the typical training stack that would involve the Python interpreter and a significantly more complex deep learning library like PyTorch/JAX, huggingface/transformers, etc. In 2019, training GPT-2 was an involved project for an entire team and was considered a big model run, but ~5 years later, due to improvements in compute (H100 GPUs), software (CUDA, cuBLAS, cuDNN, FlashAttention) and data (e.g.
the FineWeb-Edu dataset), we can reproduce this model on a single 8XH100 node in 24 hours, for $672, which is quite incredible. There are some caveats and asterisks involved though: llm.c is still not perfectly tuned or sufficiently stabilized (we still see loss spikes and bad activation ranges now and then), and our evals are not comprehensive (e.g. we do not carefully eval multilingual, code, math); a lot of this work is still ongoing.

Unicorns. The natural place to start is probably with unicorns. In the GPT-2 blog post, OpenAI prompted GPT-2 with the following unicorn prompt. Let's see what our model thinks about English-speaking unicorns in the Andes mountains:

Prompt: In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.

Completion: It may sound like a hoax, or just a dream of the mythical animals, but it appears this may be a real case of the unicorn language. The Andean area has an abundance of minerals and is also very rich in biodiversity. According to the researchers' study, the unicorns' language is called Elveseo. To produce it, the animal forms a consonant from the lips and makes a vowel sound with the teeth, then produces a consonant sound with the mouth and makes an 'ow' sound from its chest. "The question is, what caused this, and what does this tell us about intelligence and communication?" commented study co-author Paul Marek. Speaking in the video, the unicorns had the right way to speak, the animals were able to communicate and talk in perfect English.

Fun! :) The model is fairly coherent and qualitatively somewhere around the level of GPT-2. You can find 20 samples from both GPT-2 and the llm.c model here, or generate many more using the instructions further down.

Training. Training a GPT-2 with llm.c is quite simple because it is written in C/CUDA, so there is no need for miniconda, Python, PyTorch, etc. You will want an 8XH100 GPU box; I recommend spinning one up from Lambda Labs. But llm.c is flexible on its compute: if you have only 1 GPU you can still get your GPT-2, you'll just have to wait ~8 days instead of 1. If you have 16 GPUs (e.g. using the new Lambda 1-Click Clusters), you'll be able to train multi-node and only have to wait ~12 hours.
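The wall-clock estimates above are just the total compute of the 8XH100 run divided across however many GPUs you have; the near-linear data-parallel scaling here is my assumption, and real multi-node runs add some communication overhead. A quick back-of-the-envelope check:

```python
# Rough wall-clock scaling, assuming near-linear data-parallel scaling.
# The full 32K-step run takes ~24 hours on 8 H100s, i.e. ~192 H100-hours.
gpu_hours = 24 * 8

for num_gpus in (1, 8, 16):
    hours = gpu_hours / num_gpus
    print(f"{num_gpus:2d} GPU(s): ~{hours:.0f} hours (~{hours / 24:.1f} days)")
# -> ~192 hours (~8 days), ~24 hours (1 day), ~12 hours (0.5 days)
```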
Once you spin up your node, here are the complete instructions to train your GPT-2 (this only takes a ~minute from blank box to start stepping):

```bash
# install cudnn so we can use FlashAttention and run fast (optional)
# https://developer.nvidia.com/cudnn-downloads
# for me: CUDA 12 (run `nvcc --version`) running on Linux x86_64 Ubuntu 22.04
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get -y install libcudnn9-dev-cuda-12

# "install" cudnn-frontend to ~/
git clone https://github.com/NVIDIA/cudnn-frontend.git

# install MPI (optional, if you intend to use multiple GPUs)
# (you might also have to install NVIDIA NCCL if it doesn't come with your setup)
sudo apt -y install openmpi-bin openmpi-doc libopenmpi-dev

# download and enter the llm.c repo
git clone https://github.com/karpathy/llm.c.git
cd llm.c

# download the "starter pack" (~1GB download)
# contains the GPT2-124M weights (used in tests), tokenizer, and eval data .bin files
./dev/download_starter_pack.sh

# download the training dataset (FineWeb-Edu, 100B tokens) .bin data shards
# note: this is a total of 1001 data shards. If you only want to test things
# out and don't want to do an actual run, feel free to append the number of
# training shards to download (e.g. for just 10 shards: ./edu_fineweb.sh 10)
# the full dataset is ~200GB; we store it here in the dev/data directory.
cd dev/data
./edu_fineweb.sh

# compile (~1 min the 1st time, mostly for cuDNN; a few seconds from then on)
cd ../../
make train_gpt2cu USE_CUDNN=1

# and train! (wait 24 hours here)
mpirun -np 8 ./train_gpt2cu \
    -i "dev/data/edu_fineweb100B/edu_fineweb_train_*.bin" \
    -j "dev/data/edu_fineweb100B/edu_fineweb_val_*.bin" \
    -o "log_gpt2_1558M" \
    -v 250 -s 300000 -g 384 \
    -h 1 \
    -b 16 -t 1024 \
    -d 1048576 \
    -r 0 \
    -z 1 \
    -c 0.1 \
    -k "cosine" \
    -l 0.0006 \
    -q 0.1 \
    -u 700 \
    -n 2000 \
    -x 32000 \
    -ge 1 \
    -y 1 \
    -e "d48"
```

I will describe the args in a second.
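Before looking at the log output below, it can be useful to sanity-check a few of the numbers that will show up in it. The snippet below is back-of-the-envelope arithmetic, not llm.c code: it recomputes the gradient accumulation steps from the batch flags, the total number of training tokens, and a rough MFU estimate from the reported throughput. The ~6 FLOPs/parameter/token approximation and the ~989 TFLOPS bf16 peak per H100 are my assumptions; llm.c's exact FLOP accounting may differ slightly.

```python
# Back-of-the-envelope sanity check of the numbers printed in the training log.
# This is NOT llm.c code, just the arithmetic behind the printed values.

num_params  = 1_557_686_400   # GPT-2 1558M
num_gpus    = 8
B, T        = 16, 1024        # micro-batch size (-b) and sequence length (-t)
total_batch = 1_048_576       # tokens per optimizer step (-d), i.e. 2**20
num_steps   = 32_000          # -x

tokens_per_micro_step = B * T * num_gpus                  # 131,072
grad_accum_steps = total_batch // tokens_per_micro_step
print(grad_accum_steps)                                   # 8, as printed in the log

total_tokens = num_steps * total_batch
print(f"{total_tokens / 1e9:.1f}B tokens")                # ~33.6B tokens

# Rough MFU from the reported ~381K tok/s across the node.
tokens_per_sec = 381_000
achieved_flops = 6 * num_params * tokens_per_sec          # fwd+bwd approximation
peak_flops = num_gpus * 989e12                            # assumed bf16 peak per H100
print(f"MFU ~{100 * achieved_flops / peak_flops:.0f}%")   # ~45%, close to the ~47% printed
```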
You'll see a bunch of prints scroll through and then the optimization will begin:

```
num_parameters: 1557686400 => bytes: 3115372800
allocated 2971 MiB for model parameters
batch_size B=16 * seq_len T=1024 * num_processes=8 and total_batch_size=1048576
=> setting grad_accum_steps=8
created directory: log_gpt2_1558M
allocating 40409 MiB for activations
val loss 11.129390
allocating 2971 MiB for parameter gradients
allocating 742 MiB for AdamW optimizer state m
allocating 742 MiB for AdamW optimizer state v
allocating 742 MiB for master copy of params
step 1/32000 | loss 11.133732 (+nanz)| norm 52.9732 (+nanz)| lr 8.57e-07 | 3056.36 ms | 42.6% bf16 MFU | 343080 tok/s
step 2/32000 | loss 10.539388 (+nanz)| norm 43.5996 (+nanz)| lr 1.71e-06 | 2747.19 ms | 47.4% bf16 MFU | 381690 tok/s
step 3/32000 | loss 9.894109 (+nanz)| norm 23.2229 (+nanz)| lr 2.57e-06 | 2753.25 ms | 47.3% bf16 MFU | 381259 tok/s
step 4/32000 | loss 9.566241 (+nanz)| norm 28.4920 (+nanz)| lr 3.43e-06 | 2741.47 ms | 47.5% bf16 MFU | 381690 tok/s
step 5/32000 | loss 9.482848 (+nanz)| norm 23.7817 (+nanz)| lr 4.29e-06 | 2752.07 ms | 47.3% bf16 MFU | 381507 tok/s
step 6/32000 | loss 9.332832 (+nanz)| norm 15.9113 (+nanz)| lr 5.14e-06 | 2751.01 ms | 47.3% bf16 MFU | 381431 tok/s
step 7/32000 | loss 9.165650 (+nanz)| norm 10.5941 (+nanz)| lr 6.00e-06 | 2753.03 ms | 47.3% bf16 MFU | 381327 tok/s
step 8/32000 | loss 9.132234 (+nanz)| norm 16.2733 (+nanz)| lr 6.86e-06 | 2748.91 ms | 47.3% bf16 MFU | 381348 tok/s
step 9/32000 | loss 9.097384 (+nanz)| norm 12.1342 (+nanz)| lr 7.71e-06 | 2748.73 ms | 47.3% bf16 MFU | 381367 tok/s
step 10/32000 | loss 9.072879 (+nanz)| norm 10.5923 (+nanz)| lr 8.57e-06 | 2749.40 ms | 47.3% bf16 MFU | 381369 tok/s
...
```

We can see that each step takes about 2.75 seconds and there are 32,000 of them, so now we wait ~24 hours. At every step, this training run takes a chunk of ~1 million tokens of FineWeb-EDU (these are educational web pages from the internet) and updates the 1558 million weights of the model to be slightly better at predicting the next token in a sequence. By the end we'll have processed 32,000 * 1048576 = 33.6B tokens in total. The loss goes down as we do a better job of predicting the next token. The gradient norm stabilizes around 0.1-1, and the learning rate is warmed up over the first 700 steps. Our model flops utilization (MFU) is around 50%, i.e. quite efficient.

Now wait 24 hours for this to finish, then you can visualize the main.log log file using the dev/vislog.ipynb Jupyter notebook. For this you will also need Python and matplotlib installed, and you will see the following:

[Figure: two panes of training curves - FineWeb-EDU validation loss on the left, HellaSwag accuracy on the right.]

Evals. On the left we are tracking the loss on FineWeb-EDU validation data. If you simply run the GPT-2 released by OpenAI and evaluate its loss on this split, you get the red horizontal line (loss 2.83). You see that our run outperforms this very quickly, by step ~5,000. However, this is not a fair comparison because GPT-2 was trained on the never-released WebText dataset, so there is a possibly large distribution shift. So if you, e.g., finetune the OpenAI model for 1,000 steps at LR 1e-4, the loss quickly plunges to the blue line (loss 2.61), because it quickly adapts to the new data statistics. I like to look at the validation loss as a sanity check, but for the actual comparison we'd want to look at fixed, 3rd-party evaluations. One of the well-behaved, smooth, common, often-cited evals that also offers early signal is the HellaSwag eval.
These are simple common-sense scenarios and the model has to pick the correct continuation. We evaluate HellaSwag in the right pane, where we see that we cross over the GPT-2 model around step ~25K (earlier than GPT-2, which is estimated to have been trained on ~100B tokens; this possibly has to do with increased data quality, as we also observed in our earlier 124M run). The green line is the GPT-3 model of the same size, which is pretty much the same model architecture as GPT-2 with minor differences (context length 1024 -> 2048) but trained for 300B tokens (i.e. ~10X more tokens than we trained on here). I should say that even HellaSwag is not an ideal single point of comparison because it tests simple English and common sense; it does not test e.g. multilingual, math or code. It could be that the WebText data mixture was a lot heavier on these, and that these domains were "stealing" model capacity to some extent - we don't know, because the dataset was never released. Lastly, in general, good evals are harder at low model capability like GPT-2's, because e.g. the models don't understand multiple choice, and their samples are not of high enough quality to make an above-chance dent in standard math or code evals.

Args guide. Let's now look at the args we passed into the training in more detail. The GPT-2 release from OpenAI included model weights but very few details, while the GPT-3 release had no weights but many details. So in many cases we follow the GPT-3 paper hyperparameters, because the GPT-2 paper has very little information:

* -i, -j are the training and validation split token files, downloaded earlier with edu_fineweb.sh
* -o is the output directory to write logs and checkpoints into
* -v 250 asks to evaluate and log the validation loss every 250 steps
* -s 300000 asks to sample some tokens every 300000 steps. Because the total number of steps will be less than this, this is a hacky way to turn sampling off; we will only sample a single time at the very end.
* -g 384 sets the number of tokens to be sampled at the end to 384
* -h 1 asks to evaluate the HellaSwag accuracy
* -b 16 sets the micro-batch size to 16. If you are running out of memory, decrease this value, e.g. try 8, 4, 2, all the way down to 1 potentially.
* -t 1024 sets the maximum sequence length to 1024, as GPT-2 did
* -d 1048576 asks that the total batch size be 2^20 tokens, following the GPT-3 paper hyperparameters table. The code will make sure to meet this desired total batch size and calculate the needed gradient accumulation "inner loop" steps of the optimization. For example, up above we saw that we have 8 GPUs each doing 16 x 1024 tokens, so that is 8 x 16 x 1024 = 131,072 tokens per micro-step (a single forward+backward), so the code calculated 8 gradient accumulation steps to meet the desired ~1M batch size per step, i.e. it does forward+backward 8 times and then a single update.
* -r 0 sets recompute to zero. Recompute is a way to trade off compute and memory. If -r 1, then we recompute a piece of the forward pass (the GeLU) during the backward pass. This means we don't have to cache it, saving memory at the cost of some more compute. So if you're running out of memory, try -r 1, or -r 2 (also recompute layernorms).
* -z 1 turns on ZeRO-1 (i.e. optimizer state sharding) across multiple GPUs. If you're training with > 1 GPU, this setting is a no-brainer and should basically always be on. On 1 GPU this setting is a no-op.
* -c 0.1 sets the weight decay to 0.1.
Only (2D) weights are decayed, exactly as in GPT-2, and this number comes from the GPT-3 paper.
* -k "cosine" sets the cosine learning rate schedule, which is the default, so this is a bit spurious.
* -l 0.0006 sets the maximum learning rate to 6e-4. The GPT-3 paper says to use 2e-4 for this model size, but here we triple it and it seems to train faster and without any issues. This wasn't tuned very carefully yet.
* -q 0.1 says that we will decay the learning rate to 10% of the max LR over the course of training, following the GPT-3 paper.
* -u 700 says that we will ramp up the learning rate from 0 to the max learning rate over the first 700 iterations, which at our total batch size of ~1M is ~700M tokens, in the spirit of the GPT-3 paper. (A small sketch of the resulting warmup + cosine schedule is included after the Code section below.)
* -n 2000 asks to save model checkpoints every 2000 steps.
* -x 32000 asks for 32K steps in total. I chose this number because it is a nice number, and it just fits into 24 hours.
* -ge 1 sets a very recently merged GeLU recompute setting for cuBLASLt (optional)
* -y 1 sets the "resume" flag on. If your training crashes or hangs for any reason, you can CTRL+C and re-run this command, and it will attempt to resume the optimization. llm.c is bitwise-deterministic, so you'll get the identical result as if you hadn't crashed.
* -e "d48" asks to initialize a depth-48 GPT-2 model from scratch.

Memory guide. The biggest constraint most people will probably face is that their GPU doesn't have 80GB. That's okay, you should still be able to run everything above if you are patient; it will just run slower. So if the model doesn't fit, what do you play with? The most important knob is the micro-batch size -b. Try to decrease it, but keep it to nice numbers, e.g. 16 -> 8 -> 4 -> 2 -> 1. From there, also try playing with the recompute setting -r, which is 0 (fastest, a lot of memory), 1 (very slightly slower, but a huge memory saving), or 2 (slightly slower, smaller memory saving). The next thing you can do is disable the master weights in fp32, which you can do with -w 0 (1 is the default); we then won't maintain an fp32 copy of the params. Empirically, in a few earlier runs this seemed to be okay, likely due to our use of stochastic rounding. If even that doesn't fit (that's unlikely, right?), you could try to decrease the maximum sequence length with -t; the default is 1024 and you can take it down to 512, 256, etc., but now you are making your model worse because you're decreasing its maximum attention span.

Code. Certainly I feel biased, but llm.c is quite beautiful:
* It only requires basic CUDA dependencies to run.
* It is a direct, minimal and readable implementation in C/CUDA. llm.c totals about 5,000 lines of C/CUDA code. We try to stay mostly C, not C++, to keep it simple. Neural net training is just one while loop of the same simple arithmetic operations (think +, -, *, /) on a single float array; it really shouldn't be that complicated.
* It compiles and runs very quickly (a few seconds), so you're doing more stepping and less waiting.
* It allocates all of its GPU memory a single time at the start, and from then on has an exactly constant memory footprint during training. So once you start stepping, you know you're good for the rest of the run and won't OOM.
* It is bitwise deterministic.
* It is efficient, at just below ~50% MFU.

The main entry point and the majority of the code is in the file train_gpt2.cu. It contains the GPT-2 model definition and the training loop in ~2,000 LOC, and it imports a bunch of helper files with various utilities and the individual layer implementations from the llmc directory. cloc llmc reports 23 files with 3,170 LOC, and cloc train_gpt2.cu is 1,353 LOC at the moment.
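Circling back to the learning-rate flags above (-l 0.0006, -u 700, -q 0.1, -k "cosine", -x 32000): together they describe a linear warmup followed by a cosine decay down to 10% of the max LR. The snippet below is a plain-Python reconstruction of that recipe, not the llm.c source, and the exact edge-case handling in llm.c may differ slightly; the warmup values do match the lr column printed in the log above (e.g. 8.57e-07 at step 1).

```python
import math

def lr_at(step, max_lr=6e-4, warmup=700, total_steps=32_000, final_frac=0.1):
    """Warmup + cosine decay, as described by the -l/-u/-q/-k/-x flags above.

    A sketch of the recipe, not llm.c's actual implementation.
    """
    min_lr = final_frac * max_lr
    if step < warmup:
        # linear warmup from 0 to max_lr over the first `warmup` optimizer steps
        return max_lr * (step + 1) / warmup
    # cosine decay from max_lr down to min_lr over the remaining steps
    progress = (step - warmup) / (total_steps - warmup)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + coeff * (max_lr - min_lr)

print(lr_at(0))       # ~8.57e-07, matches "lr 8.57e-07" at step 1 in the log
print(lr_at(9))       # ~8.57e-06, matches step 10 in the log
print(lr_at(31_999))  # ~6.0e-05, i.e. ~10% of the max LR at the end of training
```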
Multi-node training. If you are part of the privileged GPU-rich upper class, llm.c supports multi-node training; the most GPUs I've seen someone train llm.c with is ~500. The biggest run I've done personally so far is on Lambda's new 1-click cluster feature, with 16XH100 GPUs across 2 nodes. The downsides of unemployment. The Lambda team has put up detailed instructions on how you can train llm.c models on their 1-click clusters. E.g. with the 512-GPU H100 cluster at $2,300/hr, you might be able to train your GPT-2 in ~30 minutes. You'd have to increase the total batch size (e.g. to ~8M) and possibly tune the hyperparameters a little. I haven't tried it, but it probably works and would be very cool :)

PyTorch comparison. A relatively comparable run in PyTorch would, I think, look something like this, using our parallel PyTorch implementation:

```bash
torchrun --standalone --nproc_per_node=8 train_gpt2.py \
    --input_bin "dev/data/edu_fineweb100B/edu_fineweb_train_*.bin" \
    --input_val_bin "dev/data/edu_fineweb100B/edu_fineweb_val_*.bin" \
    --write_tensors 0 \
    --model d48 \
    --batch_size 8 --sequence_length 1024 --total_batch_size 1048576 \
    --dtype bfloat16 \
    --compile 1 \
    --tensorcores 1 \
    --flash 1 \
    --num_iterations 32000 \
    --warmup_iters 700 \
    --weight_decay 0.1 \
    --overfit_single_batch 0 \
    --learning_rate 0.0006 \
    --zero_stage 1
```

The PyTorch code is meant as a testing reference, not an actual implementation, so the training loop is a little different in some places (e.g. the dataloader doesn't permute the shards, etc.), but this is still possibly useful as a point of reference. I also hacked the default vocab size from 50257 to 50304 for added efficiency; the current PyTorch nightly then gives:

```
step 16/32000 | train loss 8.903997 | norm 8.3474 | lr 1.37e-05 | (3381.88 ms | 310057 tok/s)
step 17/32000 | train loss 8.870140 | norm 3.7936 | lr 1.46e-05 | (3381.95 ms | 310051 tok/s)
step 18/32000 | train loss 8.875732 | norm 9.4993 | lr 1.54e-05 | (3393.09 ms | 309033 tok/s)
step 19/32000 | train loss 8.817432 | norm 2.8345 | lr 1.63e-05 | (3379.75 ms | 310253 tok/s)
step 20/32000 | train loss 8.798056 | norm 4.1234 | lr 1.71e-05 | (3386.53 ms | 309631 tok/s)
step 21/32000 | train loss 8.777574 | norm 2.8010 | lr 1.80e-05 | (3386.05 ms | 309675 tok/s)
...
```

Now, I wouldn't say I have full confidence that the PyTorch script is maximally tuned, but the following observations can be made. PyTorch seems to take a lot more memory (this run is at ~80GB), while llm.c is at 57GB (a 29% improvement). Memory is important because it allows you to crank up the batch size (e.g. llm.c can go up to a micro-batch size of 24 here), which goes a bit faster. Second, we're seeing about 3386 vs. 2750 ms per iteration, so llm.c is stepping ~19% faster. Some of the gains here have a known origin, e.g. llm.c includes optimizations like the fused classifier that kicks off the backward pass, which is something torch.compile does not do today afaik. It's also possible that this script isn't maximally tuned, but in any case I'm showing the comparison 1) so that others can take a look, play with it, compare, and help tune, and 2) to say that llm.c is quite optimized and fast - in the specific case of GPT-2/3 training.

The final model. A few links that may be helpful, for posterity:
* The main.log file.
* The model_00032000.bin llm.c bin model file
* The model converted to a huggingface transformers GPT-2 model, which I uploaded here: karpathy/gpt2_1558M_final2_hf

Model export. The model export can be done as follows, for example:

```bash
python dev/eval/export_hf.py --input log_gpt2_128M/model_00032000.bin --output gpt2_1558M_export
```

This then lets you run the Eleuther eval harness, or run the huggingface sampling pipeline to get model samples:

```python
# take model for spin
import torch
output = "./gpt2_1558M_final2_hf"
# set pytorch seeds
torch.manual_seed(42)
torch.cuda.manual_seed(42)

prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."

from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(output)
model = AutoModelForCausalLM.from_pretrained(output, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, device_map='cuda')
model.eval()
tokens = tokenizer.encode(prompt, return_tensors="pt")
tokens = tokens.to('cuda')
output = model.generate(tokens, max_new_tokens=500, pad_token_id=tokenizer.eos_token_id, do_sample=True, top_k=50, num_return_sequences=4)
samples = tokenizer.batch_decode(output)
for sample in samples:
    print('-'*30)
    print(sample)
```

Also have a look at dev/eval for instructions on how to run the Eleuther Evaluation Harness, the evals from the HuggingFace Open LLM Leaderboard, etc.

400B token run. I have also attempted to train GPT-2 for significantly longer than 33B tokens. In particular, I changed -x to 400,000 to train for ~420B tokens (even more than the GPT-3 model of this size, which was trained on 300B). This model run looked great until about step 330,000:

[Figure: HellaSwag accuracy for the 400B-token run, climbing to ~61% before the run destabilizes around step 330,000.]

This model dramatically beats GPT-2 and GPT-3 of its size on HellaSwag (it gets up to ~61%), but sadly becomes unstable from there on and explodes. There are more, smaller spikes along the way, but the code is configured to detect the simpler, instantaneous instability and skip the update (I used the flags -sl 5.0 -sg 5.0), which helps mitigate and defer issues (a rough sketch of this kind of update skipping is included at the end of this section). However, I think we're not yet being sufficiently careful with our initialization, activation ranges, and overall model training stability, and there are deeper issues that gradually drift the model into instability, especially for larger models and over long training durations. To be continued. If you have ideas or recommendations for stabilizing LLM training, please contribute your experience in the discussion below.

FAQ:
* Can I sample from the model in llm.c? Kind of, but it's inefficient and a bit weird, and even more hacky if you'd like to prompt the model. Use the huggingface paths above for now.
* Can I chat with it? No, this is currently only pretraining, not chat finetuning.
* Can you train in fp8? No, we're currently mostly training in bf16, but early fp8 versions are very much a work in progress.
* I have a non-NVIDIA GPU, can I run llm.c? No, llm.c supports C/CUDA only, but good forks exist (see the main README). For example, there is an actively maintained AMD fork by @anthonix that is quite good.

GPT-2 (124M). I also wanted to link to an earlier post on training the GPT-2 (124M) model in llm.c, which has some more information related to llm.c runs. 124M is a smaller model in the GPT-2 miniseries, only 124M parameters compared to 1558M.
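On the -sl 5.0 -sg 5.0 flags mentioned in the 400B-token-run section above: the general idea of outlier-based update skipping - skipping an optimizer update when the loss or gradient norm is a large outlier relative to its recent history - can be sketched roughly as below. This is an illustrative sketch only; the window size, the z-score statistics, and the warm-up handling here are my assumptions, not a description of how llm.c actually implements these flags.

```python
# Rough sketch of outlier-based update skipping, in the spirit of -sl / -sg.
# The exact statistics and thresholds in llm.c may differ; this is illustrative.
from collections import deque
import math

class SpikeGuard:
    def __init__(self, threshold=5.0, window=128):
        self.threshold = threshold           # skip if value is > threshold z-scores away
        self.history = deque(maxlen=window)  # recent loss (or grad-norm) values

    def should_skip(self, value):
        if len(self.history) < 16:           # not enough history yet; never skip early
            self.history.append(value)
            return False
        mean = sum(self.history) / len(self.history)
        var = sum((x - mean) ** 2 for x in self.history) / len(self.history)
        z = (value - mean) / (math.sqrt(var) + 1e-8)
        self.history.append(value)
        return z > self.threshold            # e.g. a sudden loss or grad-norm spike

# usage sketch inside a training loop:
# loss_guard, grad_guard = SpikeGuard(5.0), SpikeGuard(5.0)
# if loss_guard.should_skip(loss) or grad_guard.should_skip(grad_norm):
#     continue  # skip this optimizer update
# optimizer.step()
```

The main design question is how much history to keep and what counts as an outlier; a threshold around 5 standard deviations (matching the 5.0 passed in the flags above) skips only clearly anomalous steps while leaving normal loss noise untouched.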
Authors. Substantial contributions to llm.c came from what now feels like the llm.c core dev team, in addition to myself:
* @ngc92, in all aspects of the code base
* @ademeure, in CUDA kernel optimization, low-precision training, cudnn, cublas, ...
* @gordicaleksa, in all aspects of whatever is next on the TODO list, from algorithms to code to multi-node and more
* @rosslwheeler, in CI and Windows support. If you're happily running llm.c on Windows, you should definitely thank Ross :)
* Lambda Labs, for sponsoring the GPUs used in the development of llm.c. The history here is that I've happily used Lambda for several years, and a few months ago I pretty-please asked if they'd be open to not charging my account for llm.c dev work; they agreed, so here we are - thank you for supporting llm.c!
* Nvidia and Ubicloud (www.ubicloud.com), for providing the GitHub Nvidia GPU Runners for our CI.

Coming up. Some of the next big steps we are interested in and looking at these days:
1. Further optimize the GPT-2 training hyperparameters. For some reason, the hyperparameters cited by OpenAI in the GPT-3 paper appear to be quite suboptimal, e.g. @Yuchenj_UW on X found that you can 3X the learning rate and get faster training with no apparent downsides. There might be other similar low-hanging fruit.
2. Improve training and scaling stability, e.g. more stable optimizers, schedulers, clipping, norming, muP. (Some of these PRs already exist; if you have tips on stabilizing LLM runs, please reach out with ideas to try!)
3. Mixed precision++: training with fp8 (imminent!).
4. Model inference, e.g. the KV cache is the low-hanging fruit here.
5. Finetuning: SFT, RLHF.
6. Multimodal extensions, VQVAE and friends.
7. More modern architectures, support for the Llama / Gemma model series.

The goal of llm.c remains to have a simple, minimal, clean training stack for a full-featured LLM agent, in direct C/CUDA, with companion educational materials to bring many people up to speed in this awesome field. Please feel free to use the Discussions for any FAQ and related topics, or if you'd like something faster, #llmc on Discord, or #llmdotc on the CUDA MODE Discord. We'll see you next time!

Replies (6 comments, 6 replies)

gordicaleksa · Jul 11, 2024
Next up: "this is how to train Llama 3 8B in 72 hours for $1500"

YuchenJin · Jul 11, 2024
I also wonder how hard it is to modify the current codebase to train Llama 3 8B.

ngc92 · Jul 11, 2024
RMSNorm is fairly easy, removing biases is easy (both already have a PR), SwiGLU should also be straightforward; I think the main challenge will be grouped-query attention with RoPE encoding. What is quite trivial is just scaling up the current model to 8B; in fact, I'm planning to make a PR that just adds a few more options for the model init.
One question is how to continue the model series; two options would be:

```c
// deeper
else if (depth == 60) { channels = 1920; num_heads = 30; } // 2.7B
else if (depth == 72) { channels = 2880; num_heads = 30; } // 7.3B
else if (depth == 84) { channels = 3456; num_heads = 36; } // 12.2B

// wider
else if (depth == 56) { channels = 1920; num_heads = 30; } // 2.6B
else if (depth == 64) { channels = 2880; num_heads = 30; } // 6.5B
else if (depth == 72) { channels = 3840; num_heads = 30; } // 12.9B
```

This roughly matches the GPT-3 series in parameter count, but both options are much deeper.

YuchenJin · Jul 11, 2024
Supporting the whole GPT-3 series would be interesting!

MathiasSchindler · Jul 11, 2024
400B token run: "This model dramatically beats GPT-2 and GPT-3 of its size on HellaSwag (it gets up to ~61%), but sadly becomes unstable from there on and explodes."
Would you be able to release the model right before the explosion? I would be interested to learn what instability and explosion look like in a model.

karpathy · Jul 11, 2024 (Maintainer, Author)
Ok, I'll upload it later today. I'm still trying to revive it and get the full 400B to complete. I took a close look at the weights last night and sadly I couldn't see any major issues. All of them look well behaved and in good ranges; the worst I saw is that the c_attn biases ranged from -50 to 50, which seems rather broad. Everything else was in nice ranges.

michaelklachko · Jul 11, 2024
@karpathy I'd check for overflows in the layernorm layers: if the input x is FP16, the variance might overflow: https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.py#L12

greydanus · Jul 11, 2024
I love the simplicity, power, and attention to detail. Well done! I hope to experiment with this code myself someday soon.

ex3ndr · Jul 11, 2024
What about the new FlashAttention 3? Will it slash the price in half?

jonready · Jul 11, 2024
Thanks again for your work to democratize AI. I'm only partway through makemore now, but am blown away by how simple you can make these tough topics.

rosslwheeler · Jul 11, 2024
Also want to thank Ubicloud (www.ubicloud.com) for providing the GitHub Nvidia GPU Runners for CI, and Nvidia for sponsoring this. Thank you!!! CC: @karpathy
karpathy · Jul 11, 2024 (Maintainer, Author)
Bleh, very good callout - thank you @rosslwheeler and Ubicloud!