[HN Gopher] Penzai: JAX research toolkit for building, editing, ...
___________________________________________________________________
Penzai: JAX research toolkit for building, editing, and visualizing
neural nets
Author : mccoyb
Score : 158 points
Date : 2024-04-21 16:28 UTC (6 hours ago)
(HTM) web link (github.com)
(TXT) w3m dump (github.com)
| ein0p wrote:
| Looks great, but outside Google I do not personally know anyone
| who uses Jax, and I work in this space.
| yshvrdhn wrote:
| It would be great if we could have intelligent tools for
| building neural networks in PyTorch.
| Edmond wrote:
| Would a comprehensive object construction platform with
| schema support and the ability to hook up to a compiler
| (i.e. turn object data into code, for instance) be a useful
| tool in this domain?
|
| ex: https://www.youtube.com/watch?v=fPnD6I9w84c
|
| I am the developer, happy to answer questions.
| polygamous_bat wrote:
| A small addendum: the only people I know who use Jax are
| people who work at Google, or people who had a big GCP grant
| and needed to use TPUs as a result.
| j7ake wrote:
| I'm in academia and I use Jax because it's the closest thing
| to translating maths directly into code.
| hyperbovine wrote:
| Same, Jax is extremely popular with the applied math/modeling
| crowd.
| error9348 wrote:
| Jax trends on papers with code:
|
| https://paperswithcode.com/trends
| ein0p wrote:
| Note that most of Jax's minuscule share is Google.
| nostrademons wrote:
| Was gonna ask "What's that MindSpore thing that seems to be
| taking the research world by storm?" but I Googled and it's
| apparently Huawei's open-source AI framework. Going from 1% to
| 7% market share in 2 years is nothing to sneeze at; those are
| growth rates similar to Chrome or Facebook in their heyday.
|
| It's telling that Huawei-backed MindSpore can go from 1% to
| 7% in 2 years, while Google-backed Jax is stuck at 2-3%.
| Contrary to popular narrative in the Western world, Chinese
| dominance is alive and well.
| logicchains wrote:
| >It's telling that Huawei-backed MindSpore can go from 1%
| to 7% in 2 years, while Google-backed Jax is stuck at 2-3%.
| Contrary to popular narrative in the Western world, Chinese
| dominance is alive and well.
|
| MindSpore has an advantage there because of its integrated
| support for Huawei's Ascend 910B, the only Chinese GPU that
| comes close to matching the A100. Given the US banned
| export of A100 and H100s to China, this creates artificial
| demand for the Ascend 910B chips and the MindSpore
| framework that utilises them.
| bigcat12345678 wrote:
| No, MindSpore is rising because of the chip embargo.
|
| No one is going to use stuff whose supply could be cut off
| one day.
|
| This is one signal of why Nvidia listed Huawei as a
| competitor in 4 out of 5 categories in its earnings report.
| ein0p wrote:
| Its meteoric rise started well before the chip embargo.
| I've looked into it, it liberally borrows ideas from
| other frameworks, both PyTorch and Jax, and adds some of
| its own. You lose some of the conceptual purity, but it
| makes up for it in practical usability, assuming it works
| as it says on the tin, which it may or may not. PyTorch
| also has support for Ascend as far as I can tell
| https://github.com/Ascend/pytorch, so that support does
| not necessarily explain MindSpore's relative success. Why
| MindSpore is rising so rapidly is not entirely clear to
| me. Could be something as simple as preferring a domestic
| alternative that is adequate to the task and has better
| documentation in Chinese. Could be cost of compute. Could
| be both. Nowadays, however, I do agree that the various
| embargoes would help it (as well as Huawei) a great deal.
| As a side note I wish Huawei could export its silicon to
| the West. I bet that'd result in dramatically cheaper
| compute.
| creato wrote:
| This data might just be unreliable. It had a weird spike
| in Dec 2021 that looks unusual compared to all the other
| frameworks.
| VHRanger wrote:
| China publishes a looooootttttt of papers. A lot of it is
| careerist crap.
|
| To be fair, a lot of US papers are also crap, but Chinese
| crap research is on another level. There's a reason a lot
| of top US researchers are Chinese - there's brain drain
| going on.
| mccoyb wrote:
| That's cool -- but wouldn't it be more constructive to discuss
| "the ideas" in this package anyways?
|
| For instance, it would be interesting to discern whether the
| design of PyTorch (and its modules) precludes or admits the
| same sort of visualization tooling. If you have expertise in
| PyTorch, perhaps you could help answer this sort of question?
|
| JAX's Pytrees are like "immutable structs, with array leaves"
| -- does PyTorch have a similar concept?
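|
| (For concreteness, a minimal sketch of the pytree idea in
| plain JAX: a nested container whose leaves are arrays, which
| gets flattened and rebuilt functionally rather than mutated.)
|
|     import jax
|     import jax.numpy as jnp
|
|     # A pytree is just a nested container (dicts, tuples,
|     # lists, ...) whose leaves are arrays.
|     params = {
|         "dense": {"w": jnp.ones((3, 4)), "b": jnp.zeros((4,))},
|         "scale": jnp.array(2.0),
|     }
|
|     # Flatten into a list of leaves plus a treedef that
|     # records the structure.
|     leaves, treedef = jax.tree_util.tree_flatten(params)
|
|     # Transform the leaves and rebuild; the original tree is
|     # never mutated.
|     doubled = jax.tree_util.tree_unflatten(
|         treedef, [2 * x for x in leaves])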
| ein0p wrote:
| Idk if you need that immutability actually. You could
| probably reconstruct enough to do this kind of viz from the
| autograd graph, or capture the graph and intermediates in the
| forward pass using hooks. My hunch is it should be doable.
| fpgamlirfanboy wrote:
| > does PyTorch have a similar concept
|
| of course https://github.com/pytorch/pytorch/blob/main/torch/
| utils/_py...
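|
| (For anyone curious, PyTorch's pytree utilities live in the
| private torch.utils._pytree module and roughly mirror JAX's
| API; the exact names below are from memory and may differ
| between versions.)
|
|     import torch
|     from torch.utils import _pytree as pytree  # private API
|
|     nested = {"w": torch.ones(3, 4), "b": torch.zeros(4)}
|
|     # Flatten to leaves plus a spec, then rebuild, much like
|     # jax.tree_util.tree_flatten / tree_unflatten.
|     leaves, spec = pytree.tree_flatten(nested)
|     rebuilt = pytree.tree_unflatten(
|         [2 * x for x in leaves], spec)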
| _ntka wrote:
| Isn't JAX the most widely used framework in the GenAI space?
| Most companies there use it -- Cohere, Anthropic, CharacterAI,
| xAI, Midjourney etc.
| mistrial9 wrote:
| just guessing that tech leadership at all of those traces
| back to Google somehow
| logicchains wrote:
| Not at Google but currently using Jax to leverage TPUs, because
| AWS's GPU pricing is eye-gougingly expensive. For the lower-end
| A10 GPUs, the price per GPU for a 4-GPU machine is 1.5x the
| price for a 1-GPU machine, and the price per GPU for an 8-GPU
| machine is 2x the price of a 1-GPU machine! If you want an
| A100 or H100, the only option is renting an 8-GPU instance. With
| properly TPU-optimised code you get something like 30-50% cost
| saving on GCP TPUs compared to AWS (and I say that as someone
| who otherwise doesn't like Google as a company and would prefer
| to avoid GCP if there wasn't such a significant cost
| advantage).
| KeplerBoy wrote:
| I use it for GPU accelerated signal processing. It really
| delivers on the promise of "Numpy but for GPU" better than all
| competing libraries out there.
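|
| (To illustrate the "NumPy but for GPU" workflow, a rough
| sketch, not my actual code: an FFT low-pass filter written
| exactly as you would with numpy, jitted so it runs on
| whatever accelerator is available.)
|
|     import jax
|     import jax.numpy as jnp
|
|     @jax.jit
|     def lowpass(signal, cutoff_bins):
|         # Same code you'd write with numpy.fft, with jnp
|         # swapped in for np.
|         spectrum = jnp.fft.rfft(signal)
|         mask = jnp.arange(spectrum.shape[-1]) < cutoff_bins
|         return jnp.fft.irfft(spectrum * mask,
|                              n=signal.shape[-1])
|
|     t = jnp.linspace(0.0, 100.0, 4096)
|     filtered = lowpass(jnp.sin(t), cutoff_bins=32)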
| sudosysgen wrote:
| I use it all the time, and there are also a few classes at my
| uni that use Jax. It's really great for experimentation and
| research, you can do a lot of things in Jax you just can't in,
| say, PyTorch.
| MasterScrat wrote:
| We've built our startup from scratch on JAX, selling text-to-
| image model finetuning, and it's given us a consistent edge not
| only in terms of pure performance but also in terms of "dollars
| per unit of work".
| eli_gottlieb wrote:
| If JAX had affine_grid() and grid_sample(), I'd be using it
| instead of PyTorch for my current project.
| catgary wrote:
| I've only been reading through the docs for a few moments, but
| I'm pleasantly surprised to find that the authors are using
| effect handlers to manage effectful computations in ML models.
| I was in the process of translating a model from torch to Jax
| using Equinox; this makes me think Penzai could be a better
| choice.
| patrickkidger wrote:
| I was just reading this too! I think it's a really interesting
| choice in the design space.
|
| So to elucidate this a little bit, the trade-off is that this
| is now incompatible with e.g. `jax.grad` or `lax.scan`: you
| can't compose things in the order
| `discharge_effect(jax.grad(your_model_here))`, or put an
| effectful `lax.scan` inside your forward pass, etc. The effect-
| discharging process only knows how to handle traversing pytree
| structures. (And they do mention this at the end of their
| docs.)
|
| This kind of thing was actually something I explicitly
| considered later on in Equinox, but in part decided against as
| I couldn't see a way to make that work either. The goal of
| Equinox was always absolute compatibility with arbitrary JAX
| code.
|
| Now, none of that should be taken as a bash at Penzai! They've
| made a different set of trade-offs, and if the above
| incompatibility doesn't affect your goals then indeed their
| effect system is incredibly elegant, so certainly give it a
| try. (Seriously, it's been pretty cool to see the release of
| Penzai, which explicitly acknowledges how much it's inspired by
| Equinox.)
| ddjohnson wrote:
| Author of Penzai here! In idiomatic Penzai usage, you should
| always discharge all effects before running your model. While
| it's true you can't do
| `discharge_effect(jax.grad(your_model_here))`, you can still
| do `jax.grad(discharge_effect(your_model_here))`, which is
| probably what you meant to do anyway in most cases. Once
| you've wrapped your model in a handler layer, it has a pure
| interface again, which makes it fully compatible with all
| arbitrary JAX transformations. The intended use of effects is
| as an internal helper to simplify plumbing of values into and
| out of layers, not as something that affects the top-level
| interface of using the model!
|
| (As an example of this, the GemmaTransformer example model
| uses the SideInput effect internally to do attention masking.
| But it exposes a pure functional interface by using a handler
| internally, so you can call it anywhere you could call an
| Equinox model, and you shouldn't have to think about the
| effect system at all as a user of the model.)
|
| It's not clear to me what the semantics of ordinary JAX
| transformations like `lax.scan` should be if the model has
| side effects. But if you don't have any effects in your
| model, or if you've explicitly handled them already, then
| it's perfectly fine to use `lax.scan`. This is similar to how
| it works in ordinary JAX; if you try to do a `lax.scan` over
| a function that mutates Python state, you'll probably hit an
| error or get something unexpected. But if you mutate Python
| state internally inside `lax.scan`, it works fine.
|
| I'll also note that adding support for higher-order layer
| combinators (like "layer scan") is something that's on the
| roadmap! The goal would be to support some of the fancier
| features of libraries like Flax when you need them, while
| still admitting a simple purely-functional mental model when
| you don't.
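|
| (A rough sketch of the ordering described above; the names
| here are placeholders rather than Penzai's actual API. Once
| effects are handled, the model is just a pure function, and
| grad/jit/vmap wrap it from the outside.)
|
|     import jax
|     import jax.numpy as jnp
|
|     # Stand-in for a model whose effects have already been
|     # handled, leaving a pure (params, inputs) -> outputs fn.
|     def handled_model(params, x):
|         return jnp.tanh(x @ params["w"] + params["b"])
|
|     def loss(params, x, y):
|         return jnp.mean((handled_model(params, x) - y) ** 2)
|
|     params = {"w": jnp.ones((4, 2)), "b": jnp.zeros(2)}
|     x, y = jnp.ones((8, 4)), jnp.zeros((8, 2))
|
|     # i.e. jax.grad(discharge_effect(model)), not the other
|     # way around.
|     grads = jax.grad(loss)(params, x, y)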
| ddjohnson wrote:
| Thanks! This is one of the more experimental design choices I
| made in designing Penzai, but so far I've found it to be quite
| useful.
|
| The effect system does come with a few sharp edges at the
| moment if you want to use JAX transformations inside the
| forward pass of your model (see my reply to Patrick), but I'm
| hoping to make it more flexible as time goes on. (Figuring out
| how effect systems should compose with function transformations
| is a bit nontrivial!)
|
| Please let me know if you run into any issues using Penzai for
| your model! (Also, most of Penzai's visualization and patching
| utilities should work with Equinox too, so you shouldn't
| necessarily need to fully commit to either one.)
| ubj wrote:
| Does anyone know if and how well Penzai can work with Diffrax
| [1]? I currently use Diffrax + Equinox for scientific machine
| learning. Penzai looks like an interesting alternative to
| Equinox.
|
| [1]: https://docs.kidger.site/diffrax/
| thatguysaguy wrote:
| Not sure on the specific combination, but since everything in
| Jax is functionally pure it's generally really easy to compose
| libraries. E.g. I've written code which embedded a flax model
| inside a haiku model without much effort.
| patrickkidger wrote:
| IIUC then penzai is (deliberately) sacrificing support for
| higher-order operations like `lax.{while_loop, scan, cond}` or
| `diffrax.diffeqsolve`, in return for some of the other new
| features it is trying out (treescope, effects).
|
| So it's slightly more framework-y than Equinox and will not be
| completely compatible with arbitrary JAX code. However I have
| already had a collaborator demonstrate that as long as you
| _don't_ use any higher-order operations, then treescope will
| actually work out-of-the-box with Equinox modules!
|
| So I think the answer to your question is "sort of":
|
| * As long as you only try to inspect things that are happening
| outside of your `diffrax.diffeqsolve` then you should be good
| to go. And moreover can probably do this simply by using e.g.
| Penzai's treescope directly alongside your existing Equinox
| code, without needing to move things over wholesale.
|
| * But anything inside probably isn't supported + if I
| understand their setup correctly can never be supported. (Not
| bashing Penzai there, which I think genuinely looks excellent
| -- I think it's just fundamentally tricky at a technical
| level.)
| ddjohnson wrote:
| Author of Penzai here. I think the answer is a bit more
| nuanced (and closer to "yes") than this:
|
| - If you want to use the treescope pretty-printer or the
| pz.select tree manipulation utility, those should work out-
| of-the-box with both Equinox and Diffrax. Penzai's utilities
| are designed to be as modular as possible (we explicitly try
| not to be "frameworky") so they support arbitrary JAX
| pytrees; if you run into any problems with this please file
| an issue!
|
| - If you want to call a Penzai model inside
| `diffrax.diffeqsolve`, that should also be fully supported
| out of the box. Penzai models expose a pure functional
| interface when called, so you should be able to call a Penzai
| model anywhere that you'd call an Equinox model. From the
| perspective of the model user, you should be able to think of
| the effect system as an implementation detail. Again, if you
| run into problems here, please file an issue.
|
| - If you want to write your own Penzai layer that uses
| `diffrax.diffeqsolve` internally, that should also work. You
| can put arbitrary logic inside a Penzai layer as long as it's
| pure.
|
| - The specific thing that is _not_ currently fully supported
| is: (1) defining a higher-order Penzai combinator _layer_
| that uses `diffrax.diffeqsolve` internally, (2) and having
| that layer run one of its sublayers inside the
| `diffrax.diffeqsolve` function, (3) while simultaneously
| having that internal sublayer use an effect (like random
| numbers, state, or parameter sharing), (4) where the handler
| for that effect is placed _outside_ of the combinator layer.
| This is because the temporary effect implementation node that
| gets inserted while a handler is running isn't a JAX array
| type, so you'll get a JAX error when you try to pass it
| through a function transformation.
|
| This last case is something I'd like to support as well, but
| I still need to figure out what the semantics of it should
| be. (E.g. what does it even mean to solve a differential
| equation that has a local state variable in it?) I think
| having side effects inside a transformed function is
| fundamentally hard to get right!
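|
| (A rough sketch of the second bullet: any pure callable, e.g.
| a handled Penzai model or an Equinox module, can serve as the
| vector field inside diffrax.diffeqsolve. Treat the exact
| diffrax call signature below as approximate.)
|
|     import jax.numpy as jnp
|     import diffrax
|
|     # Stand-in for a pure model call mapping state to its
|     # time derivative; parameters flow through `args`.
|     def vector_field(t, y, args):
|         w = args
|         return jnp.tanh(w @ y)
|
|     solution = diffrax.diffeqsolve(
|         diffrax.ODETerm(vector_field),
|         diffrax.Tsit5(),
|         t0=0.0, t1=1.0, dt0=0.01,
|         y0=jnp.array([1.0, 0.0]),
|         args=-0.5 * jnp.eye(2),
|     )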
| pizza wrote:
| I remember pytorch has some pytree capability, no? So is it
| safe to say that any pytree-compatible modules here are
| already compatible w/ pytorch?
| ddjohnson wrote:
| Author here! I didn't know PyTorch had its own pytree system.
| It looks like it's separate from JAX's pytree registry, though,
| so Penzai's tooling probably won't work with PyTorch out of the
| box.
| sillysaurusx wrote:
| I implemented Jax's pytrees in pure python. You can use it with
| whatever you want. https://github.com/shawwn/pytreez
|
| The readme is a todo, but the tests are complete. They're the
| same ones that Jax itself uses, but with zero dependencies.
| https://github.com/shawwn/pytreez/blob/master/tests/test_pyt...
|
| The concept is simple. The hard part is cross-pollination.
| Suppose you wanted to literally use Jax pytrees with PyTorch.
| Now you'll have to import Jax, or my library, and register your
| modules with it. But anything else that ever uses pytrees needs
| to use the same pytree library, because the registry (the thing
| that keeps track of pytree-compatible classes) is in the
| library you choose. They don't share registries.
|
| A better way of phrasing it is that if you use a jax-style
| pytree interface, it should work with any other pytree library.
| But to my knowledge, the only pytree library besides Jax itself
| is mine here, and only I use it. So when you ask if pytree-
| compatible modules are compatible with PyTorch, it's equivalent
| to asking whether PyTorch projects use jax, and the answer
| tends to be no.
|
| EDIT: perhaps I'm outdated. OP says that PyTorch has pytree
| functionality now.
| https://news.ycombinator.com/item?id=40109662 I guess yet again
| I was ahead of the times by a couple years; happy to see other
| ecosystems catch up. Hopefully seeing a simple implementation
| will clarify the tradeoffs.
|
| The best approach for a universal pytree library would be to
| assume that any class with tree_flatten and tree_unflatten
| methods is pytreeable, and not require those classes to be
| explicitly registered (sketched at the end of this comment).
| That way you don't have to worry whether
| you're using Jax or PyTorch pytrees. But I gave up trying to
| make library-agnostic ML modules; in practice it's better just
| to choose Jax or PyTorch and be done with it, since making
| PyTorch modules run in Jax automatically (and vice versa) is a
| fool's errand (I was the fool, and it was an errand) for many
| reasons, not the least of which is that Jax builds an explicit
| computation graph via jax.jit, a feature PyTorch has only
| recently (and reluctantly) embraced. But of course, that means
| if you pick the wrong ecosystem, you'll miss out on the best
| tools -- hello React vs Vue, or Unreal Engine vs Unity, or
| dozens of other examples.
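|
| (A rough sketch of the duck-typing idea above: treat anything
| that exposes tree_flatten as a pytree node and recurse, with
| no registry involved. The helper name is made up.)
|
|     def flatten_any(obj):
|         # Anything with a tree_flatten method is a node.
|         if hasattr(obj, "tree_flatten"):
|             children, _aux = obj.tree_flatten()
|             return [leaf for c in children
|                     for leaf in flatten_any(c)]
|         # Built-in containers are nodes too.
|         if isinstance(obj, (list, tuple)):
|             return [leaf for item in obj
|                     for leaf in flatten_any(item)]
|         if isinstance(obj, dict):
|             return [leaf for v in obj.values()
|                     for leaf in flatten_any(v)]
|         # Everything else is a leaf.
|         return [obj]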
___________________________________________________________________
(page generated 2024-04-21 23:00 UTC)