[HN Gopher] Show HN: How does JAX allocate memory on a TPU? An i...
       ___________________________________________________________________
        
       Show HN: How does JAX allocate memory on a TPU? An interactive C++
       walkthrough
        
       Author : sillysaurusx
       Score  : 70 points
       Date   : 2021-11-06 10:16 UTC (12 hours ago)
        
 (HTM) web link (gist.github.com)
 (TXT) w3m dump (gist.github.com)
        
       | alevskaya wrote:
        | One thing that's important to note about memory management with
        | XLA is that inside a compiled program there's no user-exposed
        | "malloc"/"free". The memory usage and schedule of a given
        | program/graph is statically optimized by the compiler (thus JAX
        | requires static shapes when jitting). When running in op-by-
        | op/eager mode, the allocated buffers are coupled to the lifetime
        | of the Python array object, and are freed when that array is
        | garbage-collected.
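        | 
        | A minimal sketch of the two modes (the shapes here are made up,
        | but the behavior is as described above):
        | 
        |     import jax
        |     import jax.numpy as jnp
        | 
        |     # Eager/op-by-op: each result owns a device buffer, freed
        |     # when the Python object is garbage-collected.
        |     x = jnp.ones((1024, 1024))
        |     y = x @ x   # allocates a new buffer on the device
        |     del x       # x's buffer becomes eligible for freeing
        | 
        |     # Jitted: XLA plans all intermediate buffers statically at
        |     # compile time, hence the static-shape requirement.
        |     @jax.jit
        |     def f(a):
        |         return (a @ a).sum()
        | 
        |     print(f(y))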
        
         | tubby12345 wrote:
         | >The memory usage and schedule of a given program/graph is
         | statically optimized by the compiler
         | 
         | Is it really though? The only thing I see is
         | 
         | https://github.com/tensorflow/tensorflow/blob/95cdeaa8c848fd...
         | 
         | which traces back to
         | 
         | https://github.com/tensorflow/tensorflow/blob/54a8a3b373918b...
         | 
          | which doesn't do anything smart that I can tell.
        
       | azalemeth wrote:
       | At the risk of sounding like a HN stereotype, for those who don't
       | know what a TPU is:
       | 
       | > TPUs are hardware accelerators specialized in deep learning
       | tasks. They are supported in Tensorflow 2.1 both through the
       | Keras high-level API, and at a lower level.
        
         | sillysaurusx wrote:
         | They're pretty rad. You can even get a few (dozen) for free:
         | https://blog.gpt4.org/jaxtpu
         | 
         | Incidentally, that blog runs on a TPU. You can just SSH into
         | them.
         | 
         | Stuff like that is why I think Jax will beat pytorch in the
         | long run: https://blog.gpt4.org/mlmind
         | 
         | A couple years from now, Metabook might find themselves in real
         | trouble. They won't be the React of ML anymore. All the
         | researchers who want to get work done have already flocked to
         | Jax, so I suspect it's a matter of time till the rest of the
         | world notices.
         | 
         | Time will tell. For now, it's by far the most fun I've had in
          | my entire career. I've been in ML for a while now, but my
         | background was gamedev, finance, and security -- I encourage
         | you to dip your toe into the weird ML scene, because it happens
         | to be a blast nowadays.
        
           | boibombeiro wrote:
            | Jax uses XLA (an IR that targets multiple platforms,
            | including the Google TPU) as its backend. Pytorch also has
            | an XLA backend.
        
             | carbocation wrote:
             | Underscoring your point, GCP has a tutorial page for
             | training Pytorch models on TPU:
             | https://cloud.google.com/tpu/docs/tutorials/pytorch-pod
        
             | sillysaurusx wrote:
             | A terrible one. It's literally one of the worst experiences
             | you can have. I say that without an ounce of bias.
             | 
             | When tpu SSH first came out, I immediately went to my
             | pytorch buddy (who became famous in the meantime by
             | pioneering the CLIP AI art you've probably been seeing:
             | https://mobile.twitter.com/rivershavewings) and said
             | "Rivers, you have to try TPUs! They're wonderful! Now you
             | can run pytorch on them directly!"
             | 
             | She was skeptical, because she'd had nothing but endless
             | problems with the official pytorch TPU notebooks (which use
             | a TPU in "remote" mode, aka it fires up an RPC server you
             | attach to).
             | 
             | "No no, trust me -- I bet their rpc driver was just
             | horrible. It can't be true that the Pytorch XLA backend is
             | terrible on TPUs. How could that be? Facebook makes awesome
             | software, and they've worked on this for so long. Some
             | intern probably wrote the rpc system or something. Anyway,
             | you can SSH into mine! See, no RPC server! What's your
             | pubkey?"
             | 
             | A few hours later, she reported that it froze in exactly
             | the same place.
             | 
             | I could hardly believe it. I was so upset that I dug really
             | deeply to try to figure out what the heck the problem was.
             | 
             | It was recompiling itself. Every. Inference.
             | 
             | Every compile took 15 minutes.
             | 
             | Their XLA backend is so far behind Jax that it's not even a
             | contest.
             | 
             | Worse, I've since realized _why_ they can't fix the
             | problem. JAX gives you precise control over which functions
             | get JIT'ed (compiled into XLA), and _when_ that happens.
             | 
              | Pytorch doesn't. There's no @torch.jit-style decorator
              | that marks a function for XLA compilation the way
              | @jax.jit does.
             | 
             | That means they need to infer which parts to jit (the
             | inference model), and which parts not to (the inference
             | loop).
             | 
             | The moment that magic fails, congrats; it's now recompiling
             | every iteration. (15 min per iteration is impressive, since
             | usually it's measured in "iterations per second"...)
             | 
             | JAX even has a TPU compiler cache now, which persists
             | across Python runs. It's like MacOS in 2021 vs Windows in
             | 1995.
             | 
             | This is deeply frustrating to me, because there's no reason
             | for it -- I don't know if Facebook is intentionally
             | kneecapping pytorch to keep people off of TPUs, or if it's
              | just run-of-the-mill incompetence. But (like Windows 95)
             | they're going to discover they're not the only game in town
             | forever.
        
               | bertr4nd wrote:
               | The pytorch programming model is just really hard to
                | adapt to an XLA-like compiler. Imperative Python code
               | doesn't translate to an ML graph compiler particularly
               | well; Jax's API is functional, so it's easier to
               | translate to the XLA API. By contrast, torch/xla uses
               | "lazy tensors" that record the computation graph and
               | compile when needed. The trouble is, if the compute graph
               | changes from run to run, you end up recompiling a lot.
               | 
                | I guess in Jax you'd just apply `jax.jit` only to the
                | parts where the compute graph is static (sketch below)?
                | I'd be curious to see examples of how this works in
                | practice. Fwiw, there's an offshoot of pytorch that is
                | aiming to provide this sort of API (see
                | https://github.com/pytorch/functorch and look at
                | eager_compilation.py).
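                | 
                | Something like this, I'd guess (an untested sketch):
                | 
                |     import jax
                |     import jax.numpy as jnp
                | 
                |     @jax.jit  # static graph: compiled once, then reused
                |     def train_step(w, x, y):
                |         def loss(w):
                |             return jnp.mean((x @ w - y) ** 2)
                |         return w - 0.01 * jax.grad(loss)(w)
                | 
                |     # dynamic control flow stays outside the jit
                |     w = jnp.zeros(4)
                |     for _ in range(10):
                |         x, y = jnp.ones((2, 4)), jnp.ones(2)
                |         w = train_step(w, x, y)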
               | 
               | (Disclaimer: I worked on this until quite recently.)
        
           | rrss wrote:
           | > that blog runs on a TPU. You can just SSH into them.
           | 
           | By "runs on a TPU," you mean it runs on a CPU in a VM
           | somewhere that also has access to TPU hardware, right?
           | 
           | If this is the new TPU phraseology, that's pretty confusing
           | IMO.
        
             | sillysaurusx wrote:
             | You're right, I spoke carelessly. The proper term is that
             | the blog is running on a "TPU VM", which is exactly what
             | you describe: a box with /dev/accel0 that libtpu.so uses to
             | communicate directly with the TPU hardware.
             | 
             | The difference is, every TPU VM has 96 cpu cores and 350GB
             | of RAM. (Pods only get 48 cores per host, but they have 1
             | host per 8 TPU cores, and the smallest pod has 32 TPU
              | cores, for a whopping 1.2TB of RAM and 192 CPU cores.)
             | 
             | Which is to say, I still think of them as "a TPU", because
             | nowhere in the world have I ever been able to access that
             | amount of raw horsepower. Not on x86_64 Ubuntu that you can
             | pip install things on, at least. Like a blog. :)
             | 
             | Wanna see a magic trick? Clone tensorflow onto a TPU VM and
             | start building it. htop will light up like a Christmas tree
             | (https://twitter.com/theshawwn/status/1400771262901854214)
             | and it'll finish in about 25 minutes flat, if I remember
             | correctly.
             | 
             | So yeah, Google is throwing around compute like Israel
             | doling out vacations to Tel Aviv for distant relatives:
             | it's totally free. Partake!
             | 
             | (I'm really looking forward to seeing Israel someday. I
             | never realized how beautiful the beaches are...)
             | 
             | 96 core TPU VMs, free as in beer! It's so exciting that I
             | just can't shut up about it.
             | 
             | TRC gives you access to 100 separate VMs the moment you
             | sign up.
             | 
             | Having access to 100 VM-yachts totally rules. SSHing into
             | 100 of them feels like commanding a WW2 carrier division,
             | or something.
             | 
             | It's quite literally too fun: I have to force myself not to
             | spend all day unlocking their secrets. There's so much new
             | territory to explore -- every day feels like a tremendous
             | adventure. Our computing ancestors could only dream of
             | exploiting as much hardware as our M1's take for granted,
             | let alone one of these behemoths. Let alone one _hundred_!
             | 
             | I went back to look at my old notes. Christmas in 2019 was
             | magical, because in January of 2020 I managed to fire up
             | santa's 300 TPUs, while a colleague fired up santa's other
             | 100 TPUs. Then we swarmed them together into a tornado of
             | compute so powerful that even _connecting_ to all 400 TPUs
              | required special configuration settings ("too many open
              | files", aka sockets):
             | https://twitter.com/theshawwn/status/1221241517626445826
             | 
             | We were training models so fast that I bet even Goku in the
             | hyperbolic time chamber would have a hard time training
             | faster.
        
               | rrss wrote:
               | Thanks for the clarification.
               | 
                | It's cool that Google has a ton of money and can give
                | people free access to lots of big servers, but what you
                | are excited about (lots of CPU cores, RAM, many VMs)
                | seems to be mostly unrelated to the actual new TPU
                | hardware, which is sorta disappointing.
        
           | manuel_w wrote:
            | > They're pretty rad. You can even get a few (dozen) for
            | free: https://blog.gpt4.org/jaxtpu
            | 
            | > Incidentally, that blog runs on a TPU. You can just SSH
            | into them.
            | 
            | As someone without a machine learning background, I assumed
            | a TPU is something like a GPU. Aren't they used for ML as
            | well? So I'm surprised you can run Linux userspace on it?
        
             | manuel_w wrote:
             | Oh, I read [1] too late. It's a (GNU/?)Linux VM with the
             | special hardware already made available.
             | 
             | [1] https://news.ycombinator.com/item?id=29129554
        
               | sillysaurusx wrote:
                | > So I'm surprised you can run Linux userspace on it?
               | 
               | For what it's worth, I was equally shocked! Felt like a
               | miracle.
               | 
                | That special hardware also turns out to be miraculously
                | easy to use: https://github.com/tensorflow/tensorflow/blo
                | b/master/tensorf...
                | 
                |     // To compile: gcc -o libtpu_client libtpu_client.c -ldl
                |     // To run: sudo ./libtpu_client
               | 
               | I had to do a double-take, because compared to the
               | uniquely hellacious experience of installing CUDA drivers
               | on Ubuntu, this seemed to be... a single header file, and
               | a single .c file.
               | 
               | Turns out, it is that easy:
               | https://twitter.com/theshawwn/status/1400749405356052483
               | 
                | And not because there's some SDK preinstalled -- it's
                | because if you unmask libtpu like a Scooby-Doo villain,
                | you'll discover libtpu is LuaJIT in disguise. You can
                | even do the equivalent of Lua's loadstring() function:
                | https://github.com/tensorflow/tensorflow/blob/dd60c07888b6e
                | 7...
               | 
               | It's just called TpuDriver_CompileProgramFromText instead
               | of loadstring, and you have to write in a quirky assembly
               | language.
               | 
                | But you don't need to write in that quirky assembly
                | language, because you can just `import jax`, jit a
                | function, then dump it as HLO text. So you can just copy-
                | paste it into your C file and run it :)
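                | 
                | For instance (a sketch; jax.xla_computation is one way
                | to get at the HLO text):
                | 
                |     import jax
                |     import jax.numpy as jnp
                | 
                |     def f(x):
                |         return jnp.sin(x) * 2.0
                | 
                |     # Trace f at a concrete shape, then print the HLO
                |     # text, ready to feed to the libtpu client.
                |     c = jax.xla_computation(f)(jnp.ones(8))
                |     print(c.as_hlo_text())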
        
               | rrss wrote:
               | > compared to the uniquely hellacious experience of
               | installing CUDA drivers on Ubuntu, this seemed to be... a
               | single header file, and a single .c file.
               | 
               | > Turns out, it is that easy:
               | https://twitter.com/theshawwn/status/1400749405356052483
               | 
               | > And not because there's some SDK preinstalled
               | 
               | Maybe not an SDK, but AFAICT it is that easy because all
               | the drivers and stuff are already installed.
               | 
                | If you start with the drivers already installed, this is
                | exactly how GPGPU works too. You can write an almost-
                | identical program using OpenCL (or CUDA) instead of
                | libtpu; the only difference is that the program is in
                | SPIR-V (or OpenCL C, PTX, CUDA C) instead of HLO.
               | 
                | And you can start with all the drivers installed by using
                | GCP or AMI images with all the GPU stuff preinstalled.
               | 
               | It is awesome that this is accessible for free.
        
           | Salgat wrote:
            | Plenty of free services like Google Colab exist for
            | Tensorflow/Pytorch and offer GPU/TPU. I pay just $10 a month
            | to train my models on it at my leisure, and the free version
            | is great if you're doing shorter training runs.
        
           | currymj wrote:
            | PyTorch is working on catching up -- I think they've already
            | got some kind of "vmap"-style function transformation in
            | beta. And I'm sure they'll figure out good higher-order
            | derivatives too. That's like 90% of what people want out of
            | Jax, so I think they'll be able to compete.
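            | 
            | For reference, in Jax that combo is just composition (a
            | minimal sketch):
            | 
            |     import jax
            |     import jax.numpy as jnp
            | 
            |     def f(x):
            |         return jnp.sin(x) ** 2
            | 
            |     # second derivative of f, vectorized over a batch
            |     d2f = jax.vmap(jax.grad(jax.grad(f)))
            |     print(d2f(jnp.linspace(0.0, 1.0, 8)))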
           | 
           | The downside of Jax is it's not easy to debug. PyTorch, for
           | better or for worse, will actually run your Python code as
           | you wrote it.
        
             | sillysaurusx wrote:
             | > The downside of Jax is it's not easy to debug. PyTorch,
             | for better or for worse, will actually run your Python code
             | as you wrote it.
             | 
             | Hmm. Jax's ease of debugging was the very first thing that
             | caught my attention:
             | https://blog.gpt4.org/jaxtpu#:~:text=pdb.set_trace()
             | 
             | > I ran it on the TPU VM, saw the loss curve go down, and
             | it was like an electric shock. "Wow! That actually...
             | worked? Huh. that's weird. Things never work on the first
             | try. I'm impressed."
             | 
             | > Then I plopped `import pdb; pdb.set_trace()` in the
             | middle of the `loss` function and ran it again. It dropped
             | me into the Python debugger.
             | 
             | > There was a tensor named `X_bt`. I typed `X_bt`. The
             | debugger _printed the value of `X_bt`_.
             | 
             | > I was able to print out all the values of every variable,
             | just like you'd expect Python to be able to do.
             | 
             | > There was a tensor named `Y_bt`. I typed `X_bt + Y_bt`. I
             | was now staring at exactly what I expected: the sum of
             | those two tensors.
             | 
             | > I could write `x + y`, or create new variables, or
             | anything else I wanted.
             | 
             | > Now I was _real_ impressed.
             | 
             | > If it sounds weird that I'm so easily impressed, it's
              | because, you gotta understand: until now, TPUs were a
             | complete pain in the ass to use. I kept my feelings to
             | myself, because I understood that the Cloud TPU team were
             | working hard to improve TPUs, and the TFRC support team was
             | wonderful, and I had so many TPUs to play with. But holy
             | moly, if you were expecting any of the above examples to
             | _just work_ on _the first try_ when using Tensorflow V1 on
             | TPUs, you were in for a rude awakening. And if you thought
             | "Well, Tensorflow v2 is supposedly a lot better, right?
             | Surely I'll be able to do basic things without
             | worrying...."
             | 
             | > ... no. Not even close. Not until Jax + TPU VMs.
             | 
             | In the subsequent year, it's been nothing but joy.
             | 
             | If the problem is that you want to see tensor values in a
             | JIT'ed function, use a host callback. You can run actual
              | Python wherever you want:
              | https://jax.readthedocs.io/en/latest/jax.experimental.host_c...
             | 
             | > This module introduces the host callback functions
             | call(), id_tap(), and id_print(), that send their arguments
             | from the device to the host and invoke user-defined Python
             | functions on the host, optionally returning results back to
             | the device computation.
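              | 
              | Concretely, a sketch using the id_print documented there:
              | 
              |     import jax
              |     import jax.numpy as jnp
              |     from jax.experimental import host_callback as hcb
              | 
              |     @jax.jit
              |     def loss(x):
              |         y = jnp.sum(x ** 2)
              |         # prints from inside the compiled function via a
              |         # device-to-host callback; returns y unchanged
              |         return hcb.id_print(y, what="loss")
              | 
              |     loss(jnp.arange(4.0))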
             | 
             | The nice part is, there's no "magic" under the hood. If you
             | get a chance, I highly recommend reading through Autodidax:
             | https://jax.readthedocs.io/en/latest/autodidax.html
             | 
             | Autodidax is a pure-python implementation of jax.
             | (Literally in one file, on that page.) It walks you through
             | how every aspect of jax works.
             | 
             | Delightfully, I found a secret branch where autodidax also
             | implements host callbacks:
              | https://github.com/google/jax/blob/effect-types/docs/autodid...
             | 
             | If you scroll to the very bottom of that file, you'll see
             | an example of compiling your own XLA JIT'ed code which
             | subsequently calls back into Python. TPUs do precisely the
             | same thing.
             | 
             | Point being:
             | 
             | > PyTorch, for better or for worse, will actually run your
             | Python code as you wrote it.
             | 
             | ... is also true of jax, to within a rounding error less
             | than "I personally don't mind writing id_print(x) instead
             | of print(x)." :)
        
               | currymj wrote:
                | Thanks, this is going to be very helpful for me. I guess
                | it's kind of like that old piece of advice: if you want
                | some free Linux tech support, just post "Linux can't do
                | this but Windows can" :)
        
             | 6gvONxR4sf7o wrote:
             | I've found jax's debugging to be in different ways better
             | and worse. The fact that the function transformations are
             | traced is great. It means you can step debug in the tracing
             | steps just as well as the actual eval steps, and you just
             | have jaxpr.Tracers instead of jnp.ndarrays, or whatever.
             | Outside of the transformations, it's just as easy to debug
             | as numpy, which is a blessing. That's one of the biggest
             | selling points.
             | 
             | Debugging jitted and pmapped code, on the other hand, is a
             | pain. Since you can always step out of them to debug, it
             | means that it's debugging performance issues that sucks.
             | And boy does it suck. If anyone knows a good story for
             | figuring out why my jitted thing is slow as hell on TPU,
             | I'm all ears. The profiling section of the official docs is
             | one of their weaker sections. (but big props to the overall
             | documentation quality!)
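              | 
              | The best I've found so far is the TensorBoard trace route,
              | roughly this (a sketch, with a made-up log dir):
              | 
              |     import jax
              |     import jax.numpy as jnp
              | 
              |     # writes a trace viewable in TensorBoard's profiler
              |     with jax.profiler.trace("/tmp/jax-trace"):
              |         x = jnp.ones((1024, 1024))
              |         (x @ x).block_until_ready()
              | 
              | It at least shows where the time goes, but it's a far cry
              | from a real answer to "why is my jitted thing slow".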
        
       | WanderPanda wrote:
        | What has always kept me from trying jax is the following
        | statement, which is pretty prominent in the jax GitHub README:
       | 
       | > This is a research project, not an official Google product.
       | Expect bugs and sharp edges. Please help by trying it out,
       | reporting bugs, and letting us know what you think!
       | 
        | Why doesn't Google push harder on this at a time when Tensorflow
        | is falling behind in mind/market share pretty drastically?!
        
         | p1esk wrote:
         | Because TF is the official one. Pushing Jax would mean giving
         | up on TF, and I don't think Google is ready to do that yet.
        
           | dekhn wrote:
            | The writing is on the wall inside Google: Pathways is Jax,
            | not TF; DM uses Jax instead of TF most of the time; and most
            | new researchers are not adopting TF. It's going to die a slow
            | death like MapReduce.
        
             | tubby12345 wrote:
             | Can't go far on HN these days before some hot take makes
             | you roll your eyes. The number of Borg cells that run DM
             | jobs is a rounding error relative to the number running TF
             | jobs.
        
       ___________________________________________________________________
       (page generated 2021-11-06 23:01 UTC)