[HN Gopher] Show HN: Tune LLaMa3.1 on Google Cloud TPUs
___________________________________________________________________
Show HN: Tune LLaMa3.1 on Google Cloud TPUs
Hey HN, we wanted to share our repo where we fine-tuned Llama 3.1
on Google TPUs. We're building AI infra to fine-tune and serve LLMs
on non-NVIDIA hardware (TPUs, Trainium, AMD GPUs). The problem:
right now, 90% of LLM workloads run on NVIDIA GPUs, but there are
equally powerful and more cost-effective alternatives out there.
For example, training and serving Llama 3.1 on Google TPUs is about
30% cheaper than on NVIDIA GPUs. But developer tooling for
non-NVIDIA chipsets is lacking, and we felt this pain ourselves. We
initially tried using PyTorch XLA to train Llama 3.1 on TPUs, but
it was rough: XLA's integration with PyTorch is clunky, libraries
were missing (bitsandbytes didn't work), and HuggingFace threw
cryptic errors. We then took a different route and translated Llama
3.1 from PyTorch to JAX, and it now runs smoothly on TPUs. There
are still challenges ahead (for example, there is no good LoRA
library in JAX), but this feels like the right path forward. Here's
a demo (https://dub.sh/felafax-demo) of our managed solution. Would
love your thoughts on our repo and vision as we keep chugging
along!
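
To give a flavor of what the PyTorch-to-JAX translation looks like
(an illustrative sketch, not the repo's actual code), here is
Llama's RMSNorm layer written as a plain JAX function:

    import jax
    import jax.numpy as jnp

    def rms_norm(x, weight, eps=1e-6):
        # Functional counterpart of the PyTorch LlamaRMSNorm module:
        # compute the RMS over the last axis in float32, then scale.
        orig_dtype = x.dtype
        x32 = x.astype(jnp.float32)
        variance = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
        normed = x32 * jax.lax.rsqrt(variance + eps)
        return (weight * normed).astype(orig_dtype)

    # jax.jit compiles the function through XLA for whatever backend
    # (CPU, GPU, or TPU) is attached.
    rms_norm_jit = jax.jit(rms_norm)
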
Author : felarof
Score : 77 points
Date : 2024-09-11 15:14 UTC (7 hours ago)
(HTM) web link (github.com)
(TXT) w3m dump (github.com)
| mandoline wrote:
| Do you have any apples-to-apples speed and cost comparisons
| across Nvidia vs. non-NVIDIA chips (as you mentioned: TPUs,
| Trainium, AMD GPUs)?
| felarof wrote:
| Google published this benchmark a year or so ago comparing TPU
| vs NVIDIA (https://github.com/GoogleCloudPlatform/vertex-ai-
| samples/blo...)
|
| The conclusion is at the bottom, but the TL;DR is that TPUs were
| 33% cheaper (performance per dollar) and that JAX scales very
| well compared to PyTorch.
|
| If you are curious, there was a thorough comparison done by
| Cohere and they published their paper
| https://arxiv.org/pdf/2309.07181 -- TPU+JAX turned out to be
| more performant and more fault tolerant (fewer weird errors).
| htrp wrote:
| What's your estimate of how much time you spent translating the
| Torch code to JAX vs. how much you spent on PyTorch XLA?
| felarof wrote:
| It took roughly 2-3 weeks to translate Torch to JAX, but I had
| past experience writing JAX from my time at Google.
|
| We spent nearly 4 weeks getting PyTorch XLA working on TPU.
| Hope that answers your question!
| ricw wrote:
| I'm surprised that it's only 30% cheaper vs. NVIDIA. How come?
| This seems to indicate that the NVIDIA premium isn't as high as
| everybody makes it out to be.
| cherioo wrote:
| NVIDIA's margin is something like 70%. Using Google TPUs is
| certainly going to erase some of that.
| felarof wrote:
| 30% is a conservative estimate (to be precise, we went with
| this benchmark: https://github.com/GoogleCloudPlatform/vertex-
| ai-samples/blo...). However, the actual difference we observe
| ranges from 30-70%.
|
| Also, calculating GPU costs is getting quite nuanced, with a
| wide range of prices (https://cloud-gpus.com/) and other
| variables that make it harder to do an apples-to-apples
| comparison.
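|
| To make the apples-to-apples math concrete, performance per
| dollar boils down to something like this (illustrative sketch
| with made-up numbers, not measured results):
|
|     def perf_per_dollar(tokens_per_sec, price_per_hour):
|         return tokens_per_sec * 3600 / price_per_hour  # tokens per $
|
|     gpu = perf_per_dollar(5000, 8.00)  # 2.25M tokens per $
|     tpu = perf_per_dollar(4000, 4.80)  # 3.00M tokens per $
|     print(tpu / gpu - 1)               # ~0.33 -> ~33% more per $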
| p1esk wrote:
| Did you try running this task (finetuning Llama) on Nvidia
| GPUs? If yes, can you provide details (which cloud instance
| and time)?
|
| I'm curious about your reported 30-70% speedup.
| felarof wrote:
| I think you slightly misunderstood, and I wasn't clear
| enough--sorry! It's not a 30-70% speedup; it's 30-70% more
| cost-efficient. This is mainly due to non-NVIDIA chipsets
| (e.g., Google TPU) being cheaper, with some additional
| efficiency gains from JAX being more closely integrated
| with the XLA architecture.
|
| No, we haven't run our JAX + XLA on NVIDIA chipsets yet.
| I'm not sure if NVIDIA has good XLA backend support.
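|
| As a small illustration of how tightly JAX sits on top of XLA (a
| generic sketch, not our training code), any jitted function can
| be lowered and its XLA program inspected directly:
|
|     import jax
|     import jax.numpy as jnp
|
|     def loss(w, x):
|         return jnp.mean((x @ w) ** 2)
|
|     w = jnp.ones((8, 2))
|     x = jnp.ones((4, 8))
|     lowered = jax.jit(loss).lower(w, x)
|     print(lowered.as_text())      # the StableHLO/XLA program
|     compiled = lowered.compile()  # compiled for the attached backend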
| p1esk wrote:
| Then how did you compute the 30-70% cost efficiency
| numbers compared to Nvidia if you haven't run this Llama
| finetuning task on Nvidia GPUs?
| felarof wrote:
| Check out this benchmark where they did an analysis:
| https://github.com/GoogleCloudPlatform/vertex-ai-
| samples/blo....
|
| At the bottom, it shows the calculations around the 30%
| cost efficiency of TPU vs GPU.
|
| Our 30-70% range is based on numbers we collected from our own
| fine-tuning runs on TPUs, compared against similar runs on
| NVIDIA GPUs (using other OSS libraries rather than our code).
| p1esk wrote:
| It would be a lot more convincing if you actually ran it
| yourself and did a proper apples to apples comparison,
| especially considering that's the whole idea behind your
| project.
| felarof wrote:
| Totally agree, thanks for the feedback! This is one of the
| TODOs on our radar.
| khimaros wrote:
| An interesting thread with speculation about how to eventually
| do this on local TPUs with llama.cpp and GGUF infrastructure:
| https://www.reddit.com/r/LocalLLaMA/comments/12o96hf/has_any...
| felarof wrote:
| Ah, the Reddit thread is referring to edge TPU devices; I'll
| check it out.
|
| Google also has Cloud TPUs, which are their server-side
| accelerators, and this is what we are initially trying to build
| for!
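|
| On a Cloud TPU VM, JAX picks up the TPU backend automatically; a
| quick sanity check of the attached accelerators looks like this
| (generic snippet, not specific to our repo):
|
|     import jax
|
|     print(jax.devices())             # e.g. a list of TpuDevice objects
|     print(jax.device_count())        # total devices across all hosts
|     print(jax.local_device_count())  # devices attached to this host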
| axpy906 wrote:
| I'm actually not surprised that JAX converts to XLA better.
| Also, deep respect for anybody in this space, as there is a lot
| of complexity to deal with at the framework and compiler level.
| felarof wrote:
| Thank you! Yeah, there are a few complexities and very little
| documentation around JAX, plus a lot of missing libraries.
| stroupwaffle wrote:
| You might want to change the Road Runner logo because it's
| definitely copyrighted.
| felarof wrote:
| Haha, yeah, good point. I'll remove it.
| xrd wrote:
| Anyone want to comment on this versus the fine-tuning speedups
| from Llama 3.1 with Unsloth?
| felarof wrote:
| Unsloth is great! They focus on single-GPU and LoRA fine-tuning
| on NVIDIA GPUs. We are initially trying to target multi-node,
| multi-TPU, full-precision training use cases.
|
| That said, in terms of single-GPU speed, we believe we would be
| behind but not too far off, thanks to JAX+TPU's more performant
| stack. Additionally, we can do larger-scale multi-node training
| on TPUs.
|
| There are still more optimizations we need to do for Llama 3.1,
| such as adding Pallas memory attention kernels, etc.
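|
| For a rough idea of what the multi-TPU side looks like in JAX (a
| hand-wavy sketch, not our actual training loop), sharding a
| weight matrix across a device mesh takes only a few lines:
|
|     import numpy as np
|     import jax
|     import jax.numpy as jnp
|     from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
|
|     # 1-D mesh over every TPU core JAX can see.
|     mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
|
|     w = jnp.zeros((8192, 8192), dtype=jnp.float32)
|     w = jax.device_put(w, NamedSharding(mesh, P("data", None)))
|
|     @jax.jit
|     def forward(x, w):
|         # XLA/GSPMD inserts the collectives the sharding implies.
|         return x @ w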
| tcdent wrote:
| Where in the codebase is the logic specific to TPU vs. CUDA?
___________________________________________________________________
(page generated 2024-09-11 23:00 UTC)