# xLSTM: Extended Long Short-Term Memory

*(xLSTM Figure)*

Paper: https://arxiv.org/abs/2405.04517

## About

xLSTM is a new Recurrent Neural Network architecture based on ideas of the original LSTM. Through Exponential Gating with appropriate normalization and stabilization techniques and a new Matrix Memory, it overcomes the limitations of the original LSTM and shows promising performance on Language Modeling when compared to Transformers or State Space Models.
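To make the Exponential Gating concrete: the paper keeps the exponential input and forget gates numerically stable by tracking a running maximum in log space and rescaling both gates by it. The sketch below is our own plain-PyTorch illustration of that idea (the function name and the sigmoid forget gate are illustrative choices), not the repository's fused sLSTM kernel:

```python
import torch
import torch.nn.functional as F

def stabilized_exp_gates(i_preact: torch.Tensor, f_preact: torch.Tensor, m_prev: torch.Tensor):
    # Stabilizer state: a running maximum in log space over the gate terms.
    log_f = F.logsigmoid(f_preact)             # illustrative; the paper also allows exp(f_preact)
    m = torch.maximum(log_f + m_prev, i_preact)
    # Both gates are rescaled by exp(-m); the normalizer state downstream is
    # rescaled by the same factor, so the cell output is unchanged.
    i_gate = torch.exp(i_preact - m)           # stabilized exponential input gate
    f_gate = torch.exp(log_f + m_prev - m)     # stabilized forget gate
    return i_gate, f_gate, m
```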
## Minimal Installation

Create a conda environment from the file `environment_pt220cu121.yaml` and install the model code only (i.e. the module `xlstm`) as a package.

Install via pip:

```bash
pip install xlstm
```

Clone from GitHub:

```bash
git clone https://github.com/NX-AI/xlstm.git
cd xlstm
pip install -e .
```

### Requirements

This package is based on PyTorch and was tested for versions `>=1.8`. For the CUDA version of sLSTM, you need Compute Capability >= 8.0 (see https://developer.nvidia.com/cuda-gpus). For a well-tested environment, install `environment_pt220cu121.yaml` as:

```bash
conda env create -n xlstm -f environment_pt220cu121.yaml
conda activate xlstm
```

## Usage

For non-language applications, or for integrating xLSTM into other architectures, you can use the `xLSTMBlockStack`. For language modeling or other token-based applications, you can use the `xLSTMLMModel`.

### xLSTM Block Stack

The `xLSTMBlockStack` is meant for use as an alternative backbone in existing projects. It is similar to a stack of Transformer blocks, but uses xLSTM blocks:

```python
import torch

from xlstm import (
    xLSTMBlockStack,
    xLSTMBlockStackConfig,
    mLSTMBlockConfig,
    mLSTMLayerConfig,
    sLSTMBlockConfig,
    sLSTMLayerConfig,
    FeedForwardConfig,
)

cfg = xLSTMBlockStackConfig(
    mlstm_block=mLSTMBlockConfig(
        mlstm=mLSTMLayerConfig(
            conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
        )
    ),
    slstm_block=sLSTMBlockConfig(
        slstm=sLSTMLayerConfig(
            backend="cuda",
            num_heads=4,
            conv1d_kernel_size=4,
            bias_init="powerlaw_blockdependent",
        ),
        feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
    ),
    context_length=256,
    num_blocks=7,
    embedding_dim=128,
    slstm_at=[1],
)

xlstm_stack = xLSTMBlockStack(cfg)

x = torch.randn(4, 256, 128).to("cuda")
xlstm_stack = xlstm_stack.to("cuda")
y = xlstm_stack(x)
y.shape == x.shape
```

If you are working with yaml strings / files for configuration, you can also use `dacite` to create the config dataclasses. This is the same as the snippet above:

```python
import torch
from omegaconf import OmegaConf
from dacite import from_dict
from dacite import Config as DaciteConfig
from xlstm import xLSTMBlockStack, xLSTMBlockStackConfig

xlstm_cfg = """
mlstm_block:
  mlstm:
    conv1d_kernel_size: 4
    qkv_proj_blocksize: 4
    num_heads: 4
slstm_block:
  slstm:
    backend: cuda
    num_heads: 4
    conv1d_kernel_size: 4
    bias_init: powerlaw_blockdependent
  feedforward:
    proj_factor: 1.3
    act_fn: gelu
context_length: 256
num_blocks: 7
embedding_dim: 128
slstm_at: [1]
"""
cfg = OmegaConf.create(xlstm_cfg)
cfg = from_dict(data_class=xLSTMBlockStackConfig, data=OmegaConf.to_container(cfg), config=DaciteConfig(strict=True))
xlstm_stack = xLSTMBlockStack(cfg)

x = torch.randn(4, 256, 128).to("cuda")
xlstm_stack = xlstm_stack.to("cuda")
y = xlstm_stack(x)
y.shape == x.shape
```

### xLSTM Language Model

The `xLSTMLMModel` is a wrapper around the `xLSTMBlockStack` that adds the token embedding and LM head:

```python
import torch
from omegaconf import OmegaConf
from dacite import from_dict
from dacite import Config as DaciteConfig
from xlstm import xLSTMLMModel, xLSTMLMModelConfig

xlstm_cfg = """
vocab_size: 50304
mlstm_block:
  mlstm:
    conv1d_kernel_size: 4
    qkv_proj_blocksize: 4
    num_heads: 4
slstm_block:
  slstm:
    backend: cuda
    num_heads: 4
    conv1d_kernel_size: 4
    bias_init: powerlaw_blockdependent
  feedforward:
    proj_factor: 1.3
    act_fn: gelu
context_length: 256
num_blocks: 7
embedding_dim: 128
slstm_at: [1]
"""
cfg = OmegaConf.create(xlstm_cfg)
cfg = from_dict(data_class=xLSTMLMModelConfig, data=OmegaConf.to_container(cfg), config=DaciteConfig(strict=True))
xlstm_stack = xLSTMLMModel(cfg)

x = torch.randint(0, 50304, size=(4, 256)).to("cuda")
xlstm_stack = xlstm_stack.to("cuda")
y = xlstm_stack(x)
y.shape[1:] == (256, 50304)
```

## Experiments

The synthetic experiments that best showcase the benefits of sLSTM over mLSTM (and vice versa) are the Parity task and the Multi-Query Associative Recall task. The Parity task can only be solved with the state-tracking capabilities provided by the memory mixing of sLSTM. The Multi-Query Associative Recall task measures memorization capabilities, where the matrix memory and state expansion of mLSTM are very beneficial. In combination, they do well on both tasks.

To run each of them, run `main.py` in the `experiments` folder like:

```bash
python experiments/main.py --config parity_xLSTM01.yaml   # xLSTM[0:1], sLSTM only
python experiments/main.py --config parity_xLSTM10.yaml   # xLSTM[1:0], mLSTM only
python experiments/main.py --config parity_xLSTM11.yaml   # xLSTM[1:1], mLSTM and sLSTM
```

Note that the training loop does not contain early stopping or test evaluation.
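For intuition on why the Parity task needs state tracking: the target at each step is the running XOR of all inputs seen so far, so a model must carry a single evolving bit across the entire sequence rather than match fixed patterns. Below is a minimal sketch of such data in PyTorch; this is our own illustration, and the actual task format is defined by the yaml configs under `experiments/`:

```python
import torch

def parity_batch(batch_size: int, seq_len: int):
    # Inputs are random bits; the label at step t is the parity (XOR)
    # of all inputs up to and including t.
    x = torch.randint(0, 2, (batch_size, seq_len))
    y = x.cumsum(dim=-1) % 2
    return x, y

x, y = parity_batch(batch_size=4, seq_len=256)  # shapes: (4, 256) each
```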
## Citation

If you use this codebase, or otherwise find our work valuable, please cite the xLSTM paper:

```
@article{xlstm,
  title={xLSTM: Extended Long Short-Term Memory},
  author={Beck, Maximilian and P{\"o}ppel, Korbinian and Spanring, Markus and Auer, Andreas and Prudnikova, Oleksandra and Kopp, Michael and Klambauer, G{\"u}nter and Brandstetter, Johannes and Hochreiter, Sepp},
  journal={arXiv preprint arXiv:2405.04517},
  year={2024}
}
```