# LoRA: Low-Rank Adaptation of Large Language Models

*(For the radio communication technique, see LoRa.)*

This repo contains the source code of the Python package `loralib` and several examples of how to integrate it with PyTorch models, such as those in HuggingFace. We only support PyTorch for now. See our paper for a detailed description of LoRA.

**LoRA: Low-Rank Adaptation of Large Language Models**
Edward J. Hu\*, Yelong Shen\*, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen
Paper: https://arxiv.org/abs/2106.09685

*Update 2/2023: LoRA is now supported by the State-of-the-art Parameter-Efficient Fine-Tuning (PEFT) library by HuggingFace.*

LoRA reduces the number of trainable parameters by learning pairs of rank-decomposition matrices while freezing the original weights. This vastly reduces the storage requirement for large language models adapted to specific tasks and enables efficient task-switching during deployment, all without introducing inference latency. LoRA also outperforms several other adaptation methods, including adapter, prefix-tuning, and fine-tuning. (A minimal sketch of the decomposition appears after the results tables below.)

We obtain results comparable or superior to full fine-tuning on the GLUE benchmark using RoBERTa (Liu et al., 2019) base and large and DeBERTa (He et al., 2020) XXL 1.5B, while only training and storing a fraction of the parameters. Click the numbers below to download the RoBERTa and DeBERTa LoRA checkpoints.

|                              | RoBERTa base Fine-tune | RoBERTa base LoRA | DeBERTa XXL Fine-tune | DeBERTa XXL LoRA |
|------------------------------|------------------------|-------------------|-----------------------|------------------|
| # of Trainable Params.       | 125M                   | 0.8M              | 1.5B                  | 4.7M             |
| MNLI (m-Acc/mm-Acc)          | 87.6                   | 87.5±.3/86.9±.3   | 91.7/91.9             | 91.9±.1/91.9±.2  |
| SST2 (Acc)                   | 94.8                   | 95.1±.2           | 97.2                  | 96.9±.2          |
| MRPC (Acc)                   | 90.2                   | 89.7±.7           | 92.0                  | 92.6±.6          |
| CoLA (Matthew's Corr)        | 63.6                   | 63.4±1.2          | 72.0                  | 72.4±1.1         |
| QNLI (Acc)                   | 92.8                   | 93.3±.3           | 96.0                  | 96.0±.1          |
| QQP (Acc)                    | 91.9                   | 90.8±.1           | 92.7                  | 92.9±.1          |
| RTE (Acc)                    | 78.7                   | 86.6±.7           | 93.9                  | 94.9±.4          |
| STSB (Pearson/Spearman Corr) | 91.2                   | 91.5±.2/91.3±.2   | 92.9/92.6             | 93.0±.2/92.9±.3  |
| Average                      | 86.40                  | 87.24             | 91.06                 | 91.32            |

Note: You still need the original pre-trained checkpoint from HuggingFace to use the LoRA checkpoints. Fine-tuning numbers are taken from Liu et al. (2019) and He et al. (2020). We include confidence intervals on results from our experiments. Please follow the instructions in `examples/NLU/` to reproduce our results.

On GPT-2, LoRA compares favorably to both full fine-tuning and other efficient tuning methods, such as adapter (Houlsby et al., 2019) and prefix tuning (Li and Liang, 2021). We evaluated on E2E NLG Challenge, DART, and WebNLG:

| Method              | # of Trainable Params | E2E (BLEU) | DART (BLEU) | WebNLG (BLEU-U/S/A)     |
|---------------------|-----------------------|------------|-------------|-------------------------|
| GPT-2 M (Fine-Tune) | 354.92M               | 68.2       | 46.0        | 30.4/63.2/47.6          |
| GPT-2 M (Adapter)   | 0.37M                 | 66.3       | 42.4        | 45.1/54.5/50.2          |
| GPT-2 M (Prefix)    | 0.35M                 | 69.7       | 45.7        | 44.1/63.1/54.4          |
| GPT-2 M (LoRA)      | 0.35M                 | 70.4±.1    | 47.1±.2     | 46.7±.4/62.1±.2/55.3±.2 |
| GPT-2 L (Fine-Tune) | 774.03M               | 68.5       | 46.5        | 41.7/64.6/54.2          |
| GPT-2 L (Adapter)   | 0.88M                 | 69.1±.1    | 45.7±.1     | 49.8±.0/61.1±.0/56.0±.0 |
| GPT-2 L (Prefix)    | 0.77M                 | 70.3       | 46.5        | 47.0/64.2/56.4          |
| GPT-2 L (LoRA)      | 0.77M                 | 70.4±.1    | 47.5±.1     | 48.4±.3/64.0±.3/57.0±.1 |

Non-LoRA baselines, except for adapter on GPT-2 large, are taken from Li and Liang (2021). We include confidence intervals on results from our experiments.

Download the GPT-2 LoRA checkpoints:

* GPT-2 Medium E2E (1.5 MB)
* GPT-2 Medium DART (1.5 MB)
* GPT-2 Medium WebNLG (1.5 MB)
* GPT-2 Large E2E (2.3 MB)
* GPT-2 Large DART (2.3 MB)
* GPT-2 Large WebNLG (2.3 MB)

Please follow the instructions in `examples/NLG/` to reproduce our results.
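For intuition, here is a minimal plain-PyTorch sketch of the decomposition LoRA learns. This is an illustration only, not the `loralib` implementation; the class name, dimensions, and initialization below are our own choices.

```python
import torch
import torch.nn as nn

class LowRankAdaptedLinear(nn.Module):
    """Illustration only: a frozen pretrained weight W0 plus a trainable low-rank update B @ A."""
    def __init__(self, in_features: int, out_features: int, r: int = 16):
        super().__init__()
        # Frozen pretrained weight (stands in for W0 loaded from a checkpoint)
        self.weight = nn.Parameter(torch.randn(out_features, in_features), requires_grad=False)
        # The trainable rank-r pair: only these r * (in_features + out_features) numbers
        # are updated and stored per task
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # starts at zero, so training starts from W0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # h = W0 x + B A x  -- the second term is the learned low-rank update
        return x @ self.weight.T + x @ (self.lora_B @ self.lora_A).T
```

Because `W0` is frozen, only the low-rank pair needs to be optimized and saved per task, and at deployment the product `B @ A` can be folded into `W0` so the adapted layer runs like an ordinary `nn.Linear`.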
## Repository Overview

*(The initial release of this repo has been archived in the branch "snapshot-9-15-2021".)*

There are several directories in this repo:

* `loralib/` contains the source code for the package `loralib`, which needs to be installed to run the examples we provide;
* `examples/NLG/` contains an example implementation of LoRA in GPT-2 using our package, which can be used to reproduce the results in our paper;
* `examples/NLU/` contains an example implementation of LoRA in RoBERTa and DeBERTa using our package, which produces competitive results on the GLUE benchmark;
* See how we use `loralib` in GPT-2, RoBERTa, and DeBERTa v2.

## Quickstart

1. Installing `loralib` is simply

   ```bash
   pip install loralib
   # Alternatively
   # pip install git+https://github.com/microsoft/LoRA
   ```

2. You can choose to adapt some layers by replacing them with counterparts implemented in `loralib`. We only support `nn.Linear`, `nn.Embedding`, and `nn.Conv2d` for now. We also support a `MergedLinear` for cases where a single `nn.Linear` represents more than one layer, such as in some implementations of the attention qkv projection (see Additional Notes for more).

   ```python
   # ===== Before =====
   # layer = nn.Linear(in_features, out_features)

   # ===== After ======
   import loralib as lora
   # Add a pair of low-rank adaptation matrices with rank r=16
   layer = lora.Linear(in_features, out_features, r=16)
   ```

3. Before the training loop begins, mark only LoRA parameters as trainable.

   ```python
   import loralib as lora
   model = BigModel()
   # This sets requires_grad to False for all parameters without the string "lora_" in their names
   lora.mark_only_lora_as_trainable(model)
   # Training loop
   for batch in dataloader:
       ...
   ```

4. When saving a checkpoint, generate a `state_dict` that only contains LoRA parameters.

   ```python
   # ===== Before =====
   # torch.save(model.state_dict(), checkpoint_path)
   # ===== After =====
   torch.save(lora.lora_state_dict(model), checkpoint_path)
   ```

5. When loading a checkpoint using `load_state_dict`, be sure to set `strict=False`.

   ```python
   # Load the pretrained checkpoint first
   model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False)
   # Then load the LoRA checkpoint
   model.load_state_dict(torch.load('ckpt_lora.pt'), strict=False)
   ```

Now training can proceed as usual.
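Putting the five steps together, a minimal end-to-end sketch might look like the following. `BigModel`, `dataloader`, and the loss computation are placeholders rather than part of `loralib`; only the `loralib` calls shown in the steps above are assumed.

```python
import torch
import loralib as lora

model = BigModel()  # placeholder: a model whose chosen layers were built with lora.Linear etc.
model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False)  # pretrained weights
lora.mark_only_lora_as_trainable(model)  # freeze everything except the LoRA parameters

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4  # placeholder hyperparameters
)

for batch in dataloader:  # placeholder dataloader
    loss = model(**batch)  # placeholder: assume the model returns its training loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Only the (small) LoRA parameters need to be saved per task
torch.save(lora.lora_state_dict(model), 'ckpt_lora.pt')
```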
## Additional Notes

1. While we focus on a simple yet effective setup, namely adapting only the q and v projections in a Transformer in our examples, LoRA can be applied to any subset of pre-trained weights. We encourage you to explore different configurations, such as adapting the embedding layer by replacing `nn.Embedding` with `lora.Embedding` and/or adapting the MLP layers. It is very likely that the optimal configuration varies for different model architectures and tasks.

2. Some Transformer implementations use a single `nn.Linear` for the projection matrices of query, key, and value. If one wishes to constrain the rank of the updates to the individual matrices, one has to either break it up into three separate matrices or use `lora.MergedLinear`. Make sure to modify the checkpoint accordingly if you choose to break up the layer.

   ```python
   # ===== Before =====
   # qkv_proj = nn.Linear(d_model, 3*d_model)

   # ===== After =====
   # Break it up (remember to modify the pretrained checkpoint accordingly)
   q_proj = lora.Linear(d_model, d_model, r=8)
   k_proj = nn.Linear(d_model, d_model)
   v_proj = lora.Linear(d_model, d_model, r=8)
   # Alternatively, use lora.MergedLinear (recommended)
   qkv_proj = lora.MergedLinear(d_model, 3*d_model, r=8, enable_lora=[True, False, True])
   ```

3. Training bias vectors in tandem with LoRA might be a cost-efficient way to squeeze out extra task performance (if you tune the learning rate carefully). While we did not study its effect thoroughly in our paper, we make it easy to try in `loralib`. You can mark some biases as trainable by passing "all" or "lora_only" to `bias=` when calling `mark_only_lora_as_trainable`. Remember to pass the corresponding `bias=` argument to `lora_state_dict` when saving a checkpoint.

   ```python
   # ===== Before =====
   # lora.mark_only_lora_as_trainable(model)  # Not training any bias vectors

   # ===== After =====
   # Training all bias vectors associated with modules we apply LoRA to
   lora.mark_only_lora_as_trainable(model, bias='lora_only')
   # Alternatively, we can train *all* bias vectors in the model, including LayerNorm biases
   lora.mark_only_lora_as_trainable(model, bias='all')
   # When saving a checkpoint, use the same bias= ('all' or 'lora_only')
   torch.save(lora.lora_state_dict(model, bias='all'), checkpoint_path)
   ```

4. Calling `model.eval()` will trigger the merging of LoRA parameters with the corresponding pretrained ones, which eliminates additional latency for subsequent forward passes. Calling `model.train()` again will undo the merge. Merging can be disabled by passing `merge_weights=False` to LoRA layers. A small illustration of this behavior follows below.
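The snippet below is a self-contained sketch of the merge behavior described in note 4; the dimensions, rank, and tolerance are arbitrary choices of ours, and only the `lora.Linear` layer itself is part of `loralib`.

```python
import torch
import loralib as lora

layer = lora.Linear(64, 64, r=8)  # merging on eval() happens unless merge_weights=False is passed
x = torch.randn(4, 64)

layer.train()           # unmerged: W0 x plus the separate low-rank path
y_unmerged = layer(x)

layer.eval()            # merged: the low-rank update is folded into the weight, a single matmul
y_merged = layer(x)

print(torch.allclose(y_unmerged, y_merged, atol=1e-6))  # the two paths agree up to floating-point error
```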
## Contact

Please contact us or post an issue if you have any questions.

For questions related to the package `loralib`:
* Edward Hu (edward@edwardjhu.com)
* Phillip Wallis (phwallis@microsoft.com)
* Weizhu Chen (wzchen@microsoft.com)

The GPT-2 example:
* Phillip Wallis (phwallis@microsoft.com)
* Yelong Shen (yeshe@microsoft.com)

The RoBERTa/DeBERTa example:
* Lu Wang (luw@microsoft.com)

## Acknowledgements

We thank, in alphabetical order, Jianfeng Gao, Jade Huang, Jiayuan Huang, Lisa Xiang Li, Xiaodong Liu, Yabin Liu, Benjamin Van Durme, Luis Vargas, Haoran Wei, Peter Welinder, and Greg Yang for providing valuable feedback.

## Citation

```bibtex
@misc{hu2021lora,
    title={LoRA: Low-Rank Adaptation of Large Language Models},
    author={Hu, Edward and Shen, Yelong and Wallis, Phil and Allen-Zhu, Zeyuan and Li, Yuanzhi and Wang, Lu and Chen, Weizhu},
    year={2021},
    eprint={2106.09685},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```

## Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information, see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.