[HN Gopher] Training Deep Networks with Data Parallelism in Jax
___________________________________________________________________
Training Deep Networks with Data Parallelism in Jax
Author : sebg
Score : 65 points
Date : 2023-02-24 17:05 UTC (5 hours ago)
(HTM) web link (www.mishalaskin.com)
(TXT) w3m dump (www.mishalaskin.com)
| qmatch wrote:
| A nice, simple walkthrough in this post, but it would be nice if
| it were updated to show how to do this with sharding and the new
| jax.Array type introduced not too long ago
|
| https://github.com/google/jax/pull/11233/files
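|
| For reference, my rough mental model of the jax.Array version is
| something like this (untested sketch; the mesh axis name, shapes,
| and loss are my own placeholders, not from the post):
|
|     import jax
|     import jax.numpy as jnp
|     import numpy as np
|     from jax.sharding import Mesh, NamedSharding
|     from jax.sharding import PartitionSpec as P
|
|     # 1D mesh over all devices, with a single "data" axis
|     mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
|
|     # shard the batch dim over "data", replicate the params
|     # (batch size must be divisible by the device count)
|     x = jax.device_put(jnp.ones((64, 128)),
|                        NamedSharding(mesh, P("data", None)))
|     w = jax.device_put(jnp.ones((128,)),
|                        NamedSharding(mesh, P()))
|
|     def loss(w, x):
|         return jnp.mean((x @ w) ** 2)
|
|     # jit sees the input shardings and inserts the collectives
|     grads = jax.jit(jax.grad(loss))(w, x)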
| amrb wrote:
| So I've been looking into ONNX to speed up inference; is there
| some killer feature I should look at JAX for?
| chas wrote:
| Jax is a great tool, but it's really best for training and
| experimentation. The transformations outlined in this post
| (amongst others) make it easy to turn simple and
| straightforward code into high performance parallel code. While
| this is changing, inference hasn't been a historical area of
| emphasis for the project, so it wouldn't be my first choice if
| that was your primary goal.
| UncleOxidant wrote:
| About a year ago I was tasked with comparing ONNX runtime
| implementations of certain components like convolution with
| some of our own in-house implementations. There was just no
| comparison. ONNX runtime has some pretty crazy fast
| implementations. Lots of magic in there. Concluded that we
| weren't going to be able to beat those without a lot of effort
| and expertise that we didn't have in our team.
| kkielhofner wrote:
| Generally from what I've seen the biggest inference speedup win
| with ONNX is to get the model to ONNX then to TRT (TensorRT) -
| assuming Nvidia hardware. Once in "TRT land" you can play with
| fp32, fp16, int8 (with calibration), etc. That said (again
| generally) ONNX does tend to perform better when compared to
| native (TF savedmodel, pytorch torchscript, whatever). With
| TensorFlow and Pytorch there are also ways to export/compile
| directly to TRT but from an ops standpoint I don't prefer this.
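|
| For the ONNX Runtime route specifically, flipping on the TRT
| execution provider is only a few lines in Python. Sketch below;
| the provider options, model path, and input shape are just
| placeholders:
|
|     import numpy as np
|     import onnxruntime as ort
|
|     # Prefer TensorRT, fall back to CUDA then CPU for any ops
|     # the TRT provider can't take
|     sess = ort.InferenceSession(
|         "model.onnx",
|         providers=[
|             ("TensorrtExecutionProvider",
|              {"trt_fp16_enable": True}),
|             "CUDAExecutionProvider",
|             "CPUExecutionProvider",
|         ],
|     )
|
|     name = sess.get_inputs()[0].name
|     out = sess.run(None, {name: np.zeros((1, 3, 224, 224),
|                                          dtype=np.float32)})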
|
| Certain inference serving solutions like Nvidia Triton
| Inference Server will even take an ONNX model and then do TRT
| compilation (with cache!) on the actual inference hardware
| dynamically at model load time. This is really nice because you
| can deploy a standard ONNX model across instances and varying
| GPU hardware and always get a TRT engine that's optimized for,
| and compatible with, the device's Compute Capability, etc.
| Really handy, and it basically comes down to a few lines of
| config in the model configuration.
|
| I'm not terribly familiar with JAX but I have to imagine
| there's ONNX export or straight to TRT export somewhere.
| mccoyb wrote:
| JAX is such a beautiful system. There are many deep PL ideas in
| the design of JAX; one could spend years thinking about them.
| It's wonderfully fun to design new interpreters implementing
| some semantics, stage them out, and automatically gain access to
| JIT compilation and accelerators via JAX's other higher-order
| primitives/transformations.
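|
| A tiny illustration of the composition (my own toy example, not
| from the post): write plain numerical code once, then stack
| transformations on top of it.
|
|     import jax
|     import jax.numpy as jnp
|
|     def loss(w, x, y):                 # ordinary numerical code
|         return jnp.mean((x @ w - y) ** 2)
|
|     # differentiate, vectorize over a batch of weight candidates,
|     # then stage the whole composite out for XLA
|     grad_fn = jax.grad(loss)
|     batched = jax.jit(jax.vmap(grad_fn, in_axes=(0, None, None)))
|
|     ws = jnp.ones((8, 3))              # 8 candidate weight vectors
|     x, y = jnp.ones((16, 3)), jnp.zeros((16,))
|     print(batched(ws, x, y).shape)     # (8, 3)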
|
| I've become a big believer that PL research in ML which makes
| heavy use of program transformations would benefit from
| providing small JAX-based implementations. There's really no
| other system which lets you express interpreter-based
| transformations with the benefits that JAX provides (maybe
| `functorch` in a few months? I have some doubts about
| transformation composition with systems like torchdynamo, but I
| don't know much about it).
|
| Edit: note this is coming from a long time Julia stan, take that
| for what it is worth :)
| UncleOxidant wrote:
| Since you're a Julia stan: Don't you think that the program
| transformations that JAX is doing could be done much more
| easily in Julia since Julia has macros? Aren't there things
| that are similar to JAX in the Julia ecosystem? (ie. a few
| different autodiff packages that do program transformation)
| mccoyb wrote:
| No -- macros are a non-intrusive transformation -- they don't
| allow you to transform callees of a function, unless you wrap
| the macro around the callee. People have tried this in Julia,
| and it's horribly slow.
|
| There's another mechanism in Julia: generated functions.
| These allow method-body specialization given type knowledge
| about the function's signature -- a user can write code which
| generates the method body, depending on the inferred types,
| once inference determines the signature (and the inferred
| signature is tight enough).
|
| All of Julia's program transformation based AD packages are
| based on the latter transformation -- most of them do
| terrible things to the compiler, including massively blowing
| up the size of code before optimization.
|
| The only package which is more promising is Diffractor -- but
| I'm not convinced it is more than a research prototype at its
| current level of development. That may change. It was written
| by one of the compiler devs and uses lower-level hooks into
| the compiler, developed to support its transformation.
|
| The big issue in general: Julia doesn't let you write
| transformations on its typed IR from user space, unless you
| want to ignore Julia's native execution engine. There are
| hooks that someone can work with -- but they aren't user-
| facing (for all but the most advanced users) -- and they
| break pass composability with code generation using the
| native engine (this may have changed since I last looked at
| this!). I would know, because I've made several attempts at
| doing stuff like this, ending up with crappy, unstable
| packages :)
|
| Separately: macros are one level of reflection -> code
| generation. JAX supports a different form -- you can't emit
| data which represents generic expressions -- it's not quite
| like Lisp in that sense. It's better to think of JAX as a
| "two-level" language system -- the meta level is Python, and
| the object level is a statically typed array language. JAX
| supports a staging operation which transforms a compatible
| subset of Python into that statically typed array language.
| But you can write interpreters in the full dynamism of Python
| -- as long as the "active paths" (under tracers) stay in that
| compatible subset, you can then stage out applications of the
| interpreters on Python functions, etc.
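|
| Concretely, the staging operation I mean is easiest to see
| with jax.make_jaxpr (toy example of mine):
|
|     import jax
|     import jax.numpy as jnp
|
|     def f(x):
|         # Python-level control flow is fine as long as it
|         # doesn't branch on traced values; only the array ops
|         # get recorded
|         for _ in range(3):
|             x = jnp.sin(x) * 2.0
|         return x
|
|     # stage the Python function out to the typed array language
|     print(jax.make_jaxpr(f)(jnp.ones(4)))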
|
| JAX provides one solution to the composable transformation
| problem, and they've done it in an elegant way - that's
| ~pretty~ easy to understand (c.f. the Autodidax tutorial).
| With my current knowledge of things, I can't effectively
| argue that Julia supports the same right now (caveat: things
| may have changed since I last had a look). This is an area
| where a lot of stuff seems to be going on in Julia, so I
| doubt it will remain this way forever.
| jdeaton wrote:
| The abstractions provided by JAX for parallelism are beautiful.
| JAX is an absolute master-class in programming-interface design
| and a lesson in the power of providing composable primitive
| operations and FP inspired design. An astounding amount of
| complexity is hidden from the user behind primitives like pmap,
| and the power is exposed in such a simple interface.
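|
| For anyone who hasn't tried it, a data-parallel SGD step is
| roughly this much code with pmap (toy loss and shapes, my own
| sketch rather than the article's exact code):
|
|     import functools
|     import jax
|     import jax.numpy as jnp
|
|     @functools.partial(jax.pmap, axis_name="batch")
|     def update(w, x, y):
|         # per-device gradient on this device's shard of the batch
|         g = jax.grad(lambda w: jnp.mean((x @ w - y) ** 2))(w)
|         # average gradients across devices, then take a step
|         g = jax.lax.pmean(g, axis_name="batch")
|         return w - 0.1 * g
|
|     n = jax.local_device_count()
|     w = jnp.zeros((n, 3))       # params replicated per device
|     x = jnp.ones((n, 32, 3))    # per-device batch shards
|     y = jnp.zeros((n, 32))
|     w = update(w, x, y)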
| alfalfasprout wrote:
| Agreed. Though keep in mind they built on a lot of failed
| attempts at doing the same to get here.
| 6gvONxR4sf7o wrote:
| That's true, and it's a massive part of what I love about JAX,
| but those primitives also form barriers in weird parts of your
| code, blocking standard introspection tools, which is the
| single thing I hate about JAX. The errors are amazingly opaque.
| mattjjatgoogle wrote:
| If you have any particular examples in mind, and time to
| share them on https://github.com/google/jax/issues, we'd love
| to try to improve them. Improving error messages is a
| priority.
|
| About introspection tools, at least for runtime value
| debugging there is to some extent a fundamental challenge:
| since jax.jit stages computation out of Python (though
| jax.grad and jax.vmap don't), it means standard Python
| runtime value inspection tools, like printing and pdb, can't
| work under a jax.jit as the values aren't available as the
| Python code is executing. You can always remove the jax.jit
| while debugging (or use `with jax.disable_jit(): ...`), but
| that's not always convenient, and we need jax.jit for good
| performance.
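|
| To spell out the disable_jit escape hatch (toy function):
|
|     import jax
|     import jax.numpy as jnp
|
|     @jax.jit
|     def f(x):
|         y = jnp.sin(x)
|         print("y is", y)   # under jit, prints an abstract tracer
|         return y * 2
|
|     f(jnp.ones(3))         # traced: print doesn't see real values
|     with jax.disable_jit():
|         f(jnp.ones(3))     # eager: print sees concrete arrays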
|
| We recently added some runtime value debugging tools which
| work even with jax.jit-staged-out code (even in automatically
| parallelized code!), though they're not the standard
| introspection tools: see `jax.debug.print` and
| `jax.debug.breakpoint` on
| https://jax.readthedocs.io/en/latest/debugging/index.html and
| https://jax.readthedocs.io/en/latest/debugging/print_breakpo.
| ...
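|
| In case it helps, the runtime-value version looks roughly like
| this (toy function):
|
|     import jax
|     import jax.numpy as jnp
|
|     @jax.jit
|     def f(x):
|         y = jnp.sin(x)
|         # prints the runtime value even though f is staged out
|         jax.debug.print("y = {y}", y=y)
|         return y * 2
|
|     f(jnp.ones(3))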
|
| If you were thinking about other kinds of introspection
| tooling, I'd love to hear about it!
| 6gvONxR4sf7o wrote:
| > with jax.disable_jit(): ...
|
| That's handy, and I hadn't seen it before, thanks.
|
| It's been a bit, but I think the most frustrating errors
| were around mapping pytrees (like this issue
| https://github.com/google/jax/issues/9928). I'm not sure of
| the exact solution, but the axis juggling and
| specifications were where I remember a lot of pain, and the
| docs (though extensive) were unclear. At times it feels
| like improvements are punted on in the hopes that xmap
| eventually fixes everything (and xmap has been in
| experimental for far longer than I expected).
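|
| (For context, the kind of axis juggling I mean is specs like
| vmap's in_axes mirroring a pytree argument -- toy example of
| mine, not the exact case from that issue:
|
|     import jax
|     import jax.numpy as jnp
|
|     def apply(params, x):
|         return x @ params["w"].T + params["b"]
|
|     # map over a leading axis of every leaf in params, not x
|     batched = {"w": jnp.ones((8, 4, 3)), "b": jnp.ones((8, 4))}
|     out = jax.vmap(apply, in_axes=({"w": 0, "b": 0}, None))(
|         batched, jnp.ones(3))
|     print(out.shape)   # (8, 4)
|
| Getting those specs wrong is where the opaque errors showed up.)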
|
| Also the barriers where I couldn't disable jit. IIRC pmap
| automatically jits, so there was no way to avoid staging
| that part out. When it came to doing some complex
| jax.lax.ppermute, it felt more difficult than it needed to
| be to debug.
|
| Next time I encounter something particularly opaque, I'll
| share on the github issue tracker.
| mattjjatgoogle wrote:
| Thanks for taking the time to explain these.
|
| > It's been a bit, but I think the most frustrating
| errors were around mapping pytrees (like this issue
| https://github.com/google/jax/issues/9928).
|
| We've improved some of these pytree error messages but it
| seems the vmap one is still not great. Thanks for the
| ping on it.
|
| > Also the barriers where I couldn't disable jit. IIRC
| pmap automatically jits, so there was no way to avoid
| staging that part out.
|
| That was indeed a longstanding issue in pmap's
| implementation. And since people came to expect jit to be
| "built in" to pmap, it wasn't easy to revise.
|
| However, we recently
| (https://github.com/google/jax/pull/11854) made
| `jax.disable_jit()` work with pmap, in the sense that it
| makes pmap execute eagerly, so that you can print/pdb/etc
| to your heart's content. (The pmap successor, shard_map
| (https://jax.readthedocs.io/en/latest/jep/14273-shard-
| map.htm...), is eager by default. Also it has uniformly
| good error messages from the start!)
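|
| A minimal sketch of what that looks like (shard_map is still
| experimental, so the import path/API may shift):
|
|     import functools
|     import jax
|     import jax.numpy as jnp
|     import numpy as np
|     from jax.sharding import Mesh, PartitionSpec as P
|     from jax.experimental.shard_map import shard_map
|
|     mesh = Mesh(np.array(jax.devices()), axis_names=("i",))
|
|     @functools.partial(shard_map, mesh=mesh,
|                        in_specs=P("i"), out_specs=P())
|     def total(x):
|         # per-device code with an explicit collective
|         return jax.lax.psum(jnp.sum(x), axis_name="i")
|
|     # eager by default, so print/pdb work inside the function
|     print(total(jnp.arange(8.)))   # 28.0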
|
| > Next time I encounter something particularly opaque,
| I'll share on the github issue tracker.
|
| Thank you for the constructive feedback!
| 6gvONxR4sf7o wrote:
| Thanks! One last thing, since I have your ear. The
| function transformation aspects of jax seem to make their
| way into downstream libraries like haiku, resulting in a
| lot of "magic" that can be difficult to examine and
| debug. Are there any utils you made to make jax's own
| transformations more transparent, which you think might
| be helpful to third party transformations?
|
| Higher order functions are difficult in general, and it
| would be fantastic to have core patterns or tools for
| breaking them open.
| mattjjatgoogle wrote:
| Thanks for the kind words! We've been doing a lot more work in
| this direction too, for both compiler-based automatic
| parallelization [0] and a work-in-progress pmap successor for
| 'manual' parallelism (per-device code and explicit collectives)
| [1] which composes seamlessly with the compiler-based stuff.
|
| [0]
| https://jax.readthedocs.io/en/latest/notebooks/Distributed_a...
|
| [1] https://jax.readthedocs.io/en/latest/jep/14273-shard-
| map.htm...
| uptownfunk wrote:
| It seems like the ecosystem is still dominated by PyTorch; is JAX
| supposed to be a competitor? Any signs of JAX overtaking PyTorch
| anytime soon? Is it perhaps just ahead of its time? Or is there a
| critical flaw in the underlying design?
| mccoyb wrote:
| I think it's young - and perhaps JAX itself is not so
| specialized to a specific task (but there are plenty of
| libraries for deep-learning-focused tooling, although not as
| mature as PyTorch). It has often been said in other threads on
| JAX, but
| it feels like a very different type of library from other AD
| systems -- the focus on concisely exposing/allowing users to
| express composable transformations seems novel! (But I may be
| mistaken)
|
| But in general, I would suspect youth.
| time_to_smile wrote:
| I think it's better to think of JAX as a more general framework
| for differentiable programming, and of PyTorch as more focused
| specifically on deep learning/neural networks.
|
| The beauty of JAX is that basic usage is basically a single
| function: `grad`.
|
| You just write whatever Python function you want and can get
| the derivative/gradient of it trivially. It gets a bit trickier
| when you need more sophisticated numeric tools like
| numpy/scipy, but in those cases it's mostly a matter of
| swapping in the JAX versions of those.
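|
| To make that concrete (trivial example):
|
|     import jax
|     import jax.numpy as jnp   # near drop-in for numpy
|
|     def f(x):
|         return jnp.sin(x) ** 2 + x
|
|     df = jax.grad(f)
|     print(df(1.0))   # 2*sin(1)*cos(1) + 1 ~= 1.909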
|
| In this sense JAX is the spiritual successor to Autograd.
| However, the really amazing thing about JAX is that not only do
| you get the autodiff basically for free, you also get very good
| performance, and essentially GPU parallelism without needing to
| think about it at all.
|
| PyTorch is an awesome library, but it's largely focused on
| building neural networks specifically. JAX should be thought of
| as a tool that basically any Python programmer can just throw
| in whenever they come across a problem that benefits from
| having differentiable code (which is a lot of cases once you
| start thinking about differentiation as a first-class feature).
___________________________________________________________________
(page generated 2023-02-24 23:00 UTC)