[HN Gopher] We fine-tuned Llama 405B on AMD GPUs
       ___________________________________________________________________
        
       We fine-tuned Llama 405B on AMD GPUs
        
       Author : felarof
       Score  : 75 points
       Date   : 2024-09-23 21:42 UTC (1 hour ago)
        
 (HTM) web link (publish.obsidian.md)
 (TXT) w3m dump (publish.obsidian.md)
        
       | felarof wrote:
       | Hey HN, we recently fine-tuned the llama3.1 405B model on 8xAMD
       | MI300x GPUs using JAX instead of PyTorch. JAX's advanced sharding
       | APIs allowed us to achieve great performance. Check out our blog
       | post to learn about the cool sharding tricks we used. We've also
       | open-sourced the code: https://github.com/felafax/felafax
       | 
       | We're a small startup building AI infra for fine-tuning and
       | serving LLMs on non-NVIDIA hardware (TPUs, AMD, Trainium).
       | 
        | Problem: Many companies are trying to get PyTorch working on AMD
        | GPUs, but we believe this is a treacherous path. PyTorch is
        | deeply intertwined with the NVIDIA ecosystem in a lot of ways
        | (e.g., `torch.cuda`, or `scaled_dot_product_attention`, which is
        | backed by NVIDIA CUDA kernels exposed through a PyTorch
        | function). So, to get PyTorch code running on non-NVIDIA
        | hardware, there's a lot of "de-NVIDIAfying" that needs to be
        | done.
       | 
        | Solution: We believe JAX is a better fit for non-NVIDIA hardware.
        | In JAX, ML model code compiles to hardware-independent HLO
        | graphs, which the XLA compiler optimizes before applying
        | hardware-specific passes. This clean separation allowed us to
        | run the same Llama 3.1 JAX code on both Google TPUs and AMD GPUs
        | with no changes.
       | 
       | Our strategy as a company is to invest upfront in porting models
       | to JAX, then leverage its framework and XLA kernels to extract
       | maximum performance from non-NVIDIA backends. This is why we
       | first ported Llama 3.1 from PyTorch to JAX, and now the same JAX
       | model works great on TPUs and runs perfectly on AMD GPUs.
       | 
       | We'd love to hear your thoughts on our vision and repo!
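        | 
        | To make the sharding and HLO points concrete, here is a minimal,
        | simplified sketch (toy shapes and names, not the exact code from
        | our repo) of how parameters get sharded across the 8 GPUs and
        | how you can inspect the hardware-independent HLO that XLA then
        | compiles for the target backend:
        | 
        |     import jax
        |     import jax.numpy as jnp
        |     import numpy as np
        |     from jax.sharding import Mesh, NamedSharding
        |     from jax.sharding import PartitionSpec as P
        | 
        |     # 1-D mesh over all local devices (8x MI300x here; the
        |     # same code builds a TPU mesh unchanged).
        |     mesh = Mesh(np.array(jax.devices()), ("model",))
        | 
        |     # Toy "weight"; real Llama params each get a PartitionSpec.
        |     w = jnp.zeros((8192, 8192))
        |     sharding = NamedSharding(mesh, P(None, "model"))
        |     w = jax.device_put(w, sharding)  # shard columns over GPUs
        | 
        |     def matmul(x, w):
        |         return x @ w
        | 
        |     matmul_jit = jax.jit(matmul)
        |     x = jnp.ones((16, 8192))
        |     y = matmul_jit(x, w)   # XLA inserts any needed collectives
        |     print(y.sharding)
        | 
        |     # The function lowers to hardware-independent HLO first;
        |     # XLA then applies the GPU- or TPU-specific passes.
        |     print(matmul_jit.lower(x, w).as_text()[:500])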
        
         | ngcc_hk wrote:
          | Given it is a migration, is there an actual comparison of the
          | same model on PyTorch vs your version? The comparison table
          | there seems to focus on the technical side.
          | 
          | Also, were there any technical issues encountered?
        
         | jgalt212 wrote:
         | Is there some cost rule of thumb to compare Nvidia, AMD, and
         | Google TPU?
        
       | latchkey wrote:
       | Nice work! I was just playing with the inference side of things
       | with 405B myself this weekend [0].
       | 
        | I'm not convinced that 'torch.cuda' is really that bad, since
        | the AMD version of PyTorch just translates that for you. It's
        | more of a naming problem than anything. The fact is that it is
        | just as easy to grab the rocm/pytorch container as it is the
        | rocm/jax container.
       | 
       | I don't see very many numbers posted. What MFU did you get?
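        | 
        | (For context, the back-of-the-envelope I'd use for MFU, with the
        | standard 6 * params * tokens/s approximation; the throughput and
        | peak numbers below are placeholders, not measurements:
        | 
        |     params = 405e9          # Llama 3.1 405B
        |     tokens_per_sec = 100.0  # hypothetical training throughput
        |     peak_bf16 = 1.3e15      # ~1.3 PFLOPS dense bf16 per MI300X
        |     n_gpus = 8
        | 
        |     model_flops_per_sec = 6 * params * tokens_per_sec  # fwd+bwd
        |     mfu = model_flops_per_sec / (peak_bf16 * n_gpus)
        |     print(f"MFU ~= {mfu:.1%}")
        | 
        | That ratio is the number I'd be curious to see.)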
       | 
       | [0] https://x.com/HotAisle/status/1837580046732874026
        
       | abalaji wrote:
        | @dang: could we get the URL to include the username, since this
        | isn't about Obsidian itself but rather a user-generated blog?
        
       | 3abiton wrote:
        | Firstly, great work! I dabbled with AMD GPUs and ROCm support a
        | year ago, and it was obvious AMD was still a long way from
        | catching up with Nvidia. While opting for JAX is an interesting
        | approach, what were the challenges for you in deviating from
        | PyTorch (being the standard library for ML)?
        
         | 6y56h56 wrote:
          | I cannot get AMD ROCm running on my Debian 12 system, which is
          | what I think is causing Ollama to use the CPU instead of the
          | GPU. So I guess there is still a long way to go.
        
           | ants_everywhere wrote:
           | I've had more luck with the ROCm docker container. I run it
           | via k8s. It was pretty painless to set up and has been mostly
           | painless since. Prior to that it was nearly impossible to get
           | Jax running reliably on ROCm.
           | 
           | Even with the container, you have to be careful installing
           | Python libraries because they can still break things.
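            | 
            | A quick sanity check inside the container that JAX actually
            | sees the GPUs rather than silently falling back to CPU
            | (rough sketch; the exact device names printed will differ):
            | 
            |     import jax
            | 
            |     # Expect a gpu/rocm backend with the MI300-class devices
            |     # listed; "cpu" here means the driver or container setup
            |     # is broken.
            |     print(jax.default_backend())
            |     print(jax.devices())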
        
       | yeahwhatever10 wrote:
       | Where is the performance data?
        
       | manojlds wrote:
        | Thought this was a post from Obsidian at first. Why haven't they
        | done the GitHub.com vs GitHub.io thing yet?
        
       ___________________________________________________________________
       (page generated 2024-09-23 23:00 UTC)