[HN Gopher] Show HN: How does JAX allocate memory on a TPU? An i...
___________________________________________________________________
Show HN: How does JAX allocate memory on a TPU? An interactive C++
walkthrough
Author : sillysaurusx
Score : 70 points
Date : 2021-11-06 10:16 UTC (12 hours ago)
(HTM) web link (gist.github.com)
(TXT) w3m dump (gist.github.com)
| alevskaya wrote:
| One thing that's important to note about memory management with
| XLA is that inside a compiled program there's no user-exposed
| "malloc"/"free". The memory usage and schedule of a given
| program/graph is statically optimized by the compiler (thus JAX
| requires static shapes when jitting). When running in op-by-
| op/eager mode, the allocated buffers are coupled to the lifetime
| of the Python array object and are freed when that array is
| garbage-collected.
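|
| Roughly, as a minimal JAX sketch of both regimes (toy shapes,
| nothing TPU-specific about it):
|
|     import jax
|     import jax.numpy as jnp
|
|     # Op-by-op/eager mode: the device buffer lives exactly as
|     # long as the Python array object that owns it.
|     x = jnp.ones((1024, 1024))   # allocates a device buffer
|     del x                        # freed once the object is GC'd
|
|     # Under jit, shapes are static: the compiler plans every
|     # buffer and the whole schedule for the traced shapes.
|     @jax.jit
|     def scaled_sum(a, b):
|         return (a + b) * 2.0
|
|     # First call compiles for (8, 128); a different shape
|     # retraces and recompiles.
|     scaled_sum(jnp.ones((8, 128)), jnp.ones((8, 128)))
|     scaled_sum(jnp.ones((16, 128)), jnp.ones((16, 128)))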
| tubby12345 wrote:
| >The memory usage and schedule of a given program/graph is
| statically optimized by the compiler
|
| Is it really though? The only thing I see is
|
| https://github.com/tensorflow/tensorflow/blob/95cdeaa8c848fd...
|
| which traces back to
|
| https://github.com/tensorflow/tensorflow/blob/54a8a3b373918b...
|
| which doesn't do anything smart that I can tell.
| azalemeth wrote:
| At the risk of sounding like a HN stereotype, for those who don't
| know what a TPU is:
|
| > TPUs are hardware accelerators specialized in deep learning
| tasks. They are supported in Tensorflow 2.1 both through the
| Keras high-level API, and at a lower level.
| sillysaurusx wrote:
| They're pretty rad. You can even get a few (dozen) for free:
| https://blog.gpt4.org/jaxtpu
|
| Incidentally, that blog runs on a TPU. You can just SSH into
| them.
|
| Stuff like that is why I think Jax will beat pytorch in the
| long run: https://blog.gpt4.org/mlmind
|
| A couple years from now, Metabook might find themselves in real
| trouble. They won't be the React of ML anymore. All the
| researchers who want to get work done have already flocked to
| Jax, so I suspect it's a matter of time till the rest of the
| world notices.
|
| Time will tell. For now, it's by far the most fun I've had in
| my entire career. I've been in ML for a while now, but my
| background was gamedev, finance, and security -- I encourage
| you to dip your toe into the weird ML scene, because it happens
| to be a blast nowadays.
| boibombeiro wrote:
| Jax uses XLA (an IR that targets multiple platforms, including
| the Google TPU) as a backend. Pytorch also has an XLA backend.
| carbocation wrote:
| Underscoring your point, GCP has a tutorial page for
| training Pytorch models on TPU:
| https://cloud.google.com/tpu/docs/tutorials/pytorch-pod
| sillysaurusx wrote:
| A terrible one. It's literally one of the worst experiences
| you can have. I say that without an ounce of bias.
|
| When tpu SSH first came out, I immediately went to my
| pytorch buddy (who became famous in the meantime by
| pioneering the CLIP AI art you've probably been seeing:
| https://mobile.twitter.com/rivershavewings) and said
| "Rivers, you have to try TPUs! They're wonderful! Now you
| can run pytorch on them directly!"
|
| She was skeptical, because she'd had nothing but endless
| problems with the official pytorch TPU notebooks (which use
| a TPU in "remote" mode, aka it fires up an RPC server you
| attach to).
|
| "No no, trust me -- I bet their rpc driver was just
| horrible. It can't be true that the Pytorch XLA backend is
| terrible on TPUs. How could that be? Facebook makes awesome
| software, and they've worked on this for so long. Some
| intern probably wrote the rpc system or something. Anyway,
| you can SSH into mine! See, no RPC server! What's your
| pubkey?"
|
| A few hours later, she reported that it froze in exactly
| the same place.
|
| I could hardly believe it. I was so upset that I dug really
| deeply to try to figure out what the heck the problem was.
|
| It was recompiling itself. Every. Inference.
|
| Every compile took 15 minutes.
|
| Their XLA backend is so far behind Jax that it's not even a
| contest.
|
| Worse, I've since realized _why_ they can't fix the
| problem. JAX gives you precise control over which functions
| get JIT'ed (compiled into XLA), and _when_ that happens.
|
| Pytorch doesn't. There's no @jax.jit-style decorator that
| tells its XLA backend which of your functions to compile.
|
| That means they need to infer which parts to jit (the
| inference model), and which parts not to (the inference
| loop).
|
| The moment that magic fails, congrats; it's now recompiling
| every iteration. (15 min per iteration is impressive, since
| usually it's measured in "iterations per second"...)
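|
| Roughly, the kind of explicit control in question (with a
| made-up toy model, just for illustration):
|
|     import jax
|     import jax.numpy as jnp
|
|     # Only the model step is compiled; the surrounding Python
|     # loop stays eager, so nothing outside predict() can force
|     # a recompile.
|     @jax.jit
|     def predict(params, x):
|         return jnp.tanh(x @ params["w"] + params["b"])
|
|     params = {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))}
|     for step in range(3):
|         x = jnp.ones((8, 4))    # same shape: compiled once
|         y = predict(params, x)  # reuses the cached executable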
|
| JAX even has a TPU compiler cache now, which persists
| across Python runs. It's like MacOS in 2021 vs Windows in
| 1995.
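|
| You can switch that cache on with something like the sketch
| below (based on the experimental API as of late 2021; the
| module path may have moved since):
|
|     from jax.experimental.compilation_cache import (
|         compilation_cache as cc,
|     )
|
|     # Compiled TPU executables are written here and reused by
|     # later Python processes instead of being recompiled.
|     cc.initialize_cache("/tmp/jax_compilation_cache")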
|
| This is deeply frustrating to me, because there's no reason
| for it -- I don't know if Facebook is intentionally
| kneecapping pytorch to keep people off of TPUs, or if it's
| just run of the mill incompetence. But (like Windows 95)
| they're going to discover they're not the only game in town
| forever.
| bertr4nd wrote:
| The pytorch programming model is just really hard to
| adapt to an XLA-like compiler. Imperative python code
| doesn't translate to an ML graph compiler particularly
| well; Jax's API is functional, so it's easier to
| translate to the XLA API. By contrast, torch/xla uses
| "lazy tensors" that record the computation graph and
| compile when needed. The trouble is, if the compute graph
| changes from run to run, you end up recompiling a lot.
|
| I guess in Jax you'd just only apply `jax.jit` to the
| parts where the compute graph is static? I'd be curious
| to see examples of how this works in practice. Fwiw,
| there's an offshoot of pytorch that is aiming to provide
| this sort of API (see
| https://github.com/pytorch/functorch and look at
| eager_compilation.py).
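|
| One hedged sketch of how that looks in practice, with a toy
| model made up for illustration: jit only the graph-static
| step, and keep the data-dependent control flow in plain
| Python.
|
|     import jax
|     import jax.numpy as jnp
|
|     @jax.jit                     # static shapes: compiled once
|     def loss_and_grad(params, batch):
|         def loss(p):
|             pred = batch["x"] @ p
|             return jnp.mean((pred - batch["y"]) ** 2)
|         return jax.value_and_grad(loss)(params)
|
|     params = jnp.zeros((4,))
|     batch = {"x": jnp.ones((16, 4)), "y": jnp.ones((16,))}
|
|     # Graph-changing, data-dependent logic stays outside the
|     # jit boundary, so it never triggers a retrace.
|     for epoch in range(10):
|         l, g = loss_and_grad(params, batch)
|         params = params - 0.1 * g
|         if l < 1e-3:             # ordinary Python control flow
|             break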
|
| (Disclaimer: I worked on this until quite recently.)
| rrss wrote:
| > that blog runs on a TPU. You can just SSH into them.
|
| By "runs on a TPU," you mean it runs on a CPU in a VM
| somewhere that also has access to TPU hardware, right?
|
| If this is the new TPU phraseology, that's pretty confusing
| IMO.
| sillysaurusx wrote:
| You're right, I spoke carelessly. The proper term is that
| the blog is running on a "TPU VM", which is exactly what
| you describe: a box with /dev/accel0 that libtpu.so uses to
| communicate directly with the TPU hardware.
|
| The difference is, every TPU VM has 96 CPU cores and 350GB
| of RAM. (Pods only get 48 cores per host, but they have 1
| host per 8 TPU cores, and the smallest pod has 32 TPU
| cores, for a whopping 1.2TB of RAM and 192 CPU cores.)
|
| Which is to say, I still think of them as "a TPU", because
| nowhere in the world have I ever been able to access that
| amount of raw horsepower. Not on x86_64 Ubuntu that you can
| pip install things on, at least. Like a blog. :)
|
| Wanna see a magic trick? Clone tensorflow onto a TPU VM and
| start building it. htop will light up like a Christmas tree
| (https://twitter.com/theshawwn/status/1400771262901854214)
| and it'll finish in about 25 minutes flat, if I remember
| correctly.
|
| So yeah, Google is throwing around compute like Israel
| doling out vacations to Tel Aviv for distant relatives:
| it's totally free. Partake!
|
| (I'm really looking forward to seeing Israel someday. I
| never realized how beautiful the beaches are...)
|
| 96 core TPU VMs, free as in beer! It's so exciting that I
| just can't shut up about it.
|
| TRC gives you access to 100 separate VMs the moment you
| sign up.
|
| Having access to 100 VM-yachts totally rules. SSHing into
| 100 of them feels like commanding a WW2 carrier division,
| or something.
|
| It's quite literally too fun: I have to force myself not to
| spend all day unlocking their secrets. There's so much new
| territory to explore -- every day feels like a tremendous
| adventure. Our computing ancestors could only dream of
| exploiting as much hardware as our M1's take for granted,
| let alone one of these behemoths. Let alone one _hundred_!
|
| I went back to look at my old notes. Christmas in 2019 was
| magical, because in January of 2020 I managed to fire up
| santa's 300 TPUs, while a colleague fired up santa's other
| 100 TPUs. Then we swarmed them together into a tornado of
| compute so powerful that even _connecting_ to all 400 TPUs
| required special configuration settings ("too many open
| files", aka sockets):
| https://twitter.com/theshawwn/status/1221241517626445826
|
| We were training models so fast that I bet even Goku in the
| hyperbolic time chamber would have a hard time training
| faster.
| rrss wrote:
| Thanks for the clarification.
|
| It's cool that google has a ton of money and can give
| people free access to lots of big servers, but it sounds
| like what you are excited about (lots of cpu cores, RAM,
| many VMs) seems to be mostly unrelated to the actual new
| TPU hardware, which is sorta disappointing.
| manuel_w wrote:
| > They're pretty rad. You can even get a few (dozen) for
| free: https://blog.gpt4.org/jaxtpu
|
| > Incidentally, that blog runs on a TPU. You can just SSH
| into them.
|
| As someone without machine learning background, I assumed a
| TPU is something like a GPU. Aren't they used for ML as well?
| So I'm surprised you can run linux userspace on it?
| manuel_w wrote:
| Oh, I read [1] too late. It's a (GNU/?)Linux VM with the
| special hardware already made available.
|
| [1] https://news.ycombinator.com/item?id=29129554
| sillysaurusx wrote:
| > So I'm surprised you can run linux userspace on it?
|
| For what it's worth, I was equally shocked! Felt like a
| miracle.
|
| That special hardware also turns out to be miraculously
| easy to use:
| https://github.com/tensorflow/tensorflow/blob/master/tensorf...
|
|     // To compile: gcc -o libtpu_client libtpu_client.c -ldl
|     // To run: sudo ./libtpu_client
|
| I had to do a double-take, because compared to the
| uniquely hellacious experience of installing CUDA drivers
| on Ubuntu, this seemed to be... a single header file, and
| a single .c file.
|
| Turns out, it is that easy:
| https://twitter.com/theshawwn/status/1400749405356052483
|
| And not because there's some SDK preinstalled -- it's
| because if you unmask libtpu like a scooby-doo villain,
| you'll discover libtpu is LuaJIT in disguise. You can
| even do the equivalent of lua's loadstring() function:
| https://github.com/tensorflow/tensorflow/blob/dd60c07888b6e7...
|
| It's just called TpuDriver_CompileProgramFromText instead
| of loadstring, and you have to write in a quirky assembly
| language.
|
| But you don't need to write in that quirky assembly
| language, because you can just `import jax`, `jax.jit` a
| function, then dump it as HLO text. So you can just copy-
| paste it into your C file and run it :)
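|
| For what it's worth, the "dump it as HLO text" step is
| roughly this sketch, using the jax.xla_computation helper
| JAX shipped at the time (the exact API has shifted in later
| versions):
|
|     import jax
|     import jax.numpy as jnp
|
|     def f(x, y):
|         return jnp.dot(x, y) + 1.0
|
|     x = jnp.ones((2, 3))
|     y = jnp.ones((3, 4))
|
|     # Trace f for these concrete shapes and print the HLO
|     # text; that text is what a libtpu client can compile.
|     print(jax.xla_computation(f)(x, y).as_hlo_text())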
| rrss wrote:
| > compared to the uniquely hellacious experience of
| installing CUDA drivers on Ubuntu, this seemed to be... a
| single header file, and a single .c file.
|
| > Turns out, it is that easy:
| https://twitter.com/theshawwn/status/1400749405356052483
|
| > And not because there's some SDK preinstalled
|
| Maybe not an SDK, but AFAICT it is that easy because all
| the drivers and stuff are already installed.
|
| If you start with the drivers already installed, this is
| exactly how GPGPU works too. You can write an almost-
| identical program using OpenCL (or CUDA) instead of
| libtpu, the only difference is that instead of this HLO
| the program is in SPIR-V (or OpenCL C, PTX, CUDA C).
|
| And you can start with all the drivers installed by using
| GCP or AMI images with all the GPU stuffs preinstalled.
|
| It is awesome that this is accessible for free.
| Salgat wrote:
| Plenty of free services like Google Colab exist for
| Tensorflow/Pytorch and offer gpu/tpu. I pay just $10 a month
| to train my models at my leisure on it, and the free version
| is great if you're doing shorter training runs.
| currymj wrote:
| PyTorch is working on catching up -- I think they've already
| got some kind of "vmap" style function transformations in
| beta. And I'm sure they'll figure out good higher order
| derivatives too. That's like 90% of what people want out of
| Jax, so I think they'll be able to compete.
|
| The downside of Jax is it's not easy to debug. PyTorch, for
| better or for worse, will actually run your Python code as
| you wrote it.
| sillysaurusx wrote:
| > The downside of Jax is it's not easy to debug. PyTorch,
| for better or for worse, will actually run your Python code
| as you wrote it.
|
| Hmm. Jax's ease of debugging was the very first thing that
| caught my attention:
| https://blog.gpt4.org/jaxtpu#:~:text=pdb.set_trace()
|
| > I ran it on the TPU VM, saw the loss curve go down, and
| it was like an electric shock. "Wow! That actually...
| worked? Huh. that's weird. Things never work on the first
| try. I'm impressed."
|
| > Then I plopped `import pdb; pdb.set_trace()` in the
| middle of the `loss` function and ran it again. It dropped
| me into the Python debugger.
|
| > There was a tensor named `X_bt`. I typed `X_bt`. The
| debugger _printed the value of `X_bt`_.
|
| > I was able to print out all the values of every variable,
| just like you'd expect Python to be able to do.
|
| > There was a tensor named `Y_bt`. I typed `X_bt + Y_bt`. I
| was now staring at exactly what I expected: the sum of
| those two tensors.
|
| > I could write `x + y`, or create new variables, or
| anything else I wanted.
|
| > Now I was _real_ impressed.
|
| > If it sounds weird that I'm so easily impressed, it's
| because, you gotta understand: until now, TPUs were a
| complete pain in the ass to use. I kept my feelings to
| myself, because I understood that the Cloud TPU team were
| working hard to improve TPUs, and the TFRC support team was
| wonderful, and I had so many TPUs to play with. But holy
| moly, if you were expecting any of the above examples to
| _just work_ on _the first try_ when using Tensorflow V1 on
| TPUs, you were in for a rude awakening. And if you thought
| "Well, Tensorflow v2 is supposedly a lot better, right?
| Surely I'll be able to do basic things without
| worrying...."
|
| > ... no. Not even close. Not until Jax + TPU VMs.
|
| In the subsequent year, it's been nothing but joy.
|
| If the problem is that you want to see tensor values in a
| JIT'ed function, use a host callback. You can run actual
| Python wherever you want:
| https://jax.readthedocs.io/en/latest/jax.experimental.host_c...
|
| > This module introduces the host callback functions
| call(), id_tap(), and id_print(), that send their arguments
| from the device to the host and invoke user-defined Python
| functions on the host, optionally returning results back to
| the device computation.
|
| The nice part is, there's no "magic" under the hood. If you
| get a chance, I highly recommend reading through Autodidax:
| https://jax.readthedocs.io/en/latest/autodidax.html
|
| Autodidax is a pure-python implementation of jax.
| (Literally in one file, on that page.) It walks you through
| how every aspect of jax works.
|
| Delightfully, I found a secret branch where autodidax also
| implements host callbacks:
| https://github.com/google/jax/blob/effect-types/docs/autodid...
|
| If you scroll to the very bottom of that file, you'll see
| an example of compiling your own XLA JIT'ed code which
| subsequently calls back into Python. TPUs do precisely the
| same thing.
|
| Point being:
|
| > PyTorch, for better or for worse, will actually run your
| Python code as you wrote it.
|
| ... is also true of jax, to within a rounding error less
| than "I personally don't mind writing id_print(x) instead
| of print(x)." :)
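|
| A minimal sketch of that id_print flow, with a toy loss
| (standard jax.experimental.host_callback usage circa 2021):
|
|     import jax
|     import jax.numpy as jnp
|     from jax.experimental import host_callback as hcb
|
|     @jax.jit
|     def loss(x):
|         y = jnp.sum(x ** 2)
|         # The value is shipped back to the host and printed
|         # there, even though loss() itself is compiled.
|         hcb.id_print(y, what="loss")
|         return y
|
|     loss(jnp.arange(4.0))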
| currymj wrote:
| Thanks, this is going to be very helpful for me. I guess
| it's kind of like that old piece of advice: if you want
| some free Linux tech support, just post "Linux can't do
| this but Windows can" :)
| 6gvONxR4sf7o wrote:
| I've found jax's debugging to be in different ways better
| and worse. The fact that the function transformations are
| traced is great. It means you can step debug in the tracing
| steps just as well as the actual eval steps, and you just
| have jaxpr.Tracers instead of jnp.ndarrays, or whatever.
| Outside of the transformations, it's just as easy to debug
| as numpy, which is a blessing. That's one of the biggest
| selling points.
|
| Debugging jitted and pmapped code, on the other hand, is a
| pain. Since you can always step out of them to debug, it
| means that it's debugging performance issues that sucks.
| And boy does it suck. If anyone knows a good story for
| figuring out why my jitted thing is slow as hell on TPU,
| I'm all ears. The profiling section of the official docs is
| one of their weaker sections. (but big props to the overall
| documentation quality!)
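|
| One hedged workflow for that: capture a profiler trace
| around the jitted call and open it in TensorBoard's profile
| plugin (the toy function and paths below are made up):
|
|     import jax
|     import jax.numpy as jnp
|
|     @jax.jit
|     def step(x):
|         return jnp.sin(x) @ jnp.cos(x).T
|
|     x = jnp.ones((2048, 2048))
|     step(x).block_until_ready()    # warm up; skip compile time
|
|     # Everything inside the context is recorded to the trace.
|     with jax.profiler.trace("/tmp/jax-trace"):
|         step(x).block_until_ready()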
| WanderPanda wrote:
| What always kept me from trying jax is the following statement
| which is pretty prominent on the JAX GitHub README:
|
| > This is a research project, not an official Google product.
| Expect bugs and sharp edges. Please help by trying it out,
| reporting bugs, and letting us know what you think!
|
| Why doesn't Google push this more, at a time when Tensorflow is
| falling behind pretty drastically in mindshare and market share?!
| p1esk wrote:
| Because TF is the official one. Pushing Jax would mean giving
| up on TF, and I don't think Google is ready to do that yet.
| dekhn wrote:
| The writing is on the wall inside Google: Pathways is Jax,
| not TF, DM uses Jax instead of TF most of the time, and most
| new researchers are not adopting TF. It's going to die a slow
| death like MapReduce.
| tubby12345 wrote:
| Can't go far on HN these days before some hot take makes
| you roll your eyes. The number of Borg cells that run DM
| jobs is a rounding error relative to the number running TF
| jobs.
___________________________________________________________________
(page generated 2021-11-06 23:01 UTC)