[HN Gopher] Penzai: JAX research toolkit for building, editing, ...
       ___________________________________________________________________
        
       Penzai: JAX research toolkit for building, editing, and visualizing
       neural nets
        
       Author : mccoyb
       Score  : 158 points
       Date   : 2024-04-21 16:28 UTC (6 hours ago)
        
 (HTM) web link (github.com)
 (TXT) w3m dump (github.com)
        
       | ein0p wrote:
       | Looks great, but outside Google I do not personally know anyone
       | who uses Jax, and I work in this space.
        
         | yshvrdhn wrote:
          | It would be great if we could have intelligent tools for
          | building neural networks in PyTorch.
        
           | Edmond wrote:
            | Would a comprehensive object construction platform with
            | schema support and the ability to hook up to a compiler
            | (i.e., turn object data into code, for instance) be a useful
            | tool in this domain?
           | 
           | ex: https://www.youtube.com/watch?v=fPnD6I9w84c
           | 
           | I am the developer, happy to answer questions.
        
         | polygamous_bat wrote:
          | A small addendum: the only people I know who use Jax are
         | people who work at Google, or people who had a big GCP grant
         | and needed to use TPUs as a result.
        
         | j7ake wrote:
          | I'm in academia and I use Jax because it's the closest thing
          | to translating maths directly into code.
        
           | hyperbovine wrote:
           | Same, Jax is extremely popular with the applied math/modeling
           | crowd.
        
         | error9348 wrote:
          | Jax trends on Papers with Code:
         | 
         | https://paperswithcode.com/trends
        
           | ein0p wrote:
            | Note that most of Jax's minuscule share comes from Google.
        
           | nostrademons wrote:
           | Was gonna ask "What's that MindSpore thing that seems to be
           | taking the research world by storm?" but I Googled and it's
           | apparently Huawei's open-source AI framework. 1% to 7% market
            | share in 2 years is nothing to sneeze at - that's a growth
            | rate similar to Chrome or Facebook in their heyday.
           | 
           | It's telling that Huawei-backed MindSpore can go from 1% to
           | 7% in 2 years, while Google-backed Jax is stuck at 2-3%.
           | Contrary to popular narrative in the Western world, Chinese
           | dominance is alive and well.
        
             | logicchains wrote:
             | >It's telling that Huawei-backed MindSpore can go from 1%
             | to 7% in 2 years, while Google-backed Jax is stuck at 2-3%.
             | Contrary to popular narrative in the Western world, Chinese
             | dominance is alive and well.
             | 
             | MindSpore has an advantage there because of its integrated
             | support for Huawei's Ascend 910B, the only Chinese GPU that
             | comes close to matching the A100. Given the US banned
             | export of A100 and H100s to China, this creates artificial
             | demand for the Ascend 910B chips and the MindSpore
             | framework that utilises them.
        
               | bigcat12345678 wrote:
                | No, MindSpore is rising because of the chip embargo.
                | 
                | No one is going to build on hardware whose supply could
                | be cut off one day.
                | 
                | This is one signal of why Nvidia listed Huawei as a
                | competitor in 4 out of 5 of its business categories in
                | its earnings report.
        
               | ein0p wrote:
                | Its meteoric rise started well before the chip embargo.
                | I've looked into it: it liberally borrows ideas from
                | other frameworks, both PyTorch and Jax, and adds some of
                | its own. You lose some of the conceptual purity, but it
                | makes up for it in practical usability, assuming it does
                | what it says on the tin, which it may or may not. PyTorch
                | also has support for Ascend as far as I can tell
                | (https://github.com/Ascend/pytorch), so that support does
                | not necessarily explain MindSpore's relative success. Why
                | MindSpore is rising so rapidly is not entirely clear to
                | me. It could be something as simple as preferring a
                | domestic alternative that is adequate to the task and has
                | better documentation in Chinese. It could be the cost of
                | compute. It could be both. Nowadays, however, I do agree
                | that the various embargoes would help it (as well as
                | Huawei) a great deal. As a side note, I wish Huawei could
                | export its silicon to the West. I bet that would result
                | in dramatically cheaper compute.
        
               | creato wrote:
                | This data might just be unreliable. It shows a spike in
                | Dec 2021 that looks unusual compared to all the other
                | frameworks.
        
               | VHRanger wrote:
               | China publishes a looooootttttt of papers. A lot of it is
               | careerist crap.
               | 
               | To be fair, a lot of US papers are also crap, but Chinese
               | crap research is on another level. There's a reason a lot
               | of top US researchers are Chinese - there's brain drain
               | going on.
        
         | mccoyb wrote:
          | That's cool -- but wouldn't it be more constructive to discuss
          | "the ideas" in this package anyway?
          | 
          | For instance, it would be interesting to discern whether the
          | design of PyTorch (and its modules) precludes or admits the
          | same sort of visualization tooling. If you have expertise in
          | PyTorch, perhaps you could help answer this sort of question?
         | 
         | JAX's Pytrees are like "immutable structs, with array leaves"
         | -- does PyTorch have a similar concept?
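          | 
          | (For readers unfamiliar with the term, a minimal sketch of the
          | JAX side -- a pytree is a nested structure of containers whose
          | array leaves JAX utilities map over as a whole:)
          | 
          |   import jax
          |   import jax.numpy as jnp
          | 
          |   # A pytree: nested dicts/tuples with array leaves.
          |   params = {
          |       "dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
          |       "scale": jnp.array(2.0),
          |   }
          | 
          |   # tree_map applies a function to every leaf, returning a
          |   # new tree of the same shape (the structure is immutable).
          |   halved = jax.tree_util.tree_map(lambda x: x / 2, params)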
        
           | ein0p wrote:
           | Idk if you need that immutability actually. You could
           | probably reconstruct enough to do this kind of viz from the
           | autograd graph, or capture the graph and intermediates in the
           | forward pass using hooks. My hunch is it should be doable.
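            | 
            | (A minimal sketch of the hook-based capture described above,
            | using PyTorch's public register_forward_hook API; the tiny
            | model here is just a stand-in:)
            | 
            |   import torch
            |   import torch.nn as nn
            | 
            |   model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(),
            |                         nn.Linear(8, 2))
            |   captured = {}
            | 
            |   def make_hook(name):
            |       # hook(module, inputs, output) fires after each
            |       # submodule's forward call
            |       def hook(module, inputs, output):
            |           captured[name] = output.detach()
            |       return hook
            | 
            |   for name, module in model.named_modules():
            |       if name:  # skip the root container itself
            |           module.register_forward_hook(make_hook(name))
            | 
            |   _ = model(torch.randn(1, 4))
            |   # captured now maps "0", "1", "2" to their outputs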
        
           | fpgamlirfanboy wrote:
           | > does PyTorch have a similar concept
           | 
           | of course https://github.com/pytorch/pytorch/blob/main/torch/
           | utils/_py...
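            | 
            | (For reference, a small usage example -- note the underscore
            | prefix: torch.utils._pytree is a private API and may change
            | between releases:)
            | 
            |   import torch
            |   from torch.utils import _pytree as pytree
            | 
            |   tree = {"w": torch.ones(2, 3), "b": torch.zeros(3)}
            |   # like jax.tree_util.tree_map: apply a function to every
            |   # leaf, preserving the container structure
            |   doubled = pytree.tree_map(lambda t: t * 2, tree)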
        
         | _ntka wrote:
         | Isn't JAX the most widely used framework in the GenAI space?
         | Most companies there use it -- Cohere, Anthropic, CharacterAI,
          | xAI, Midjourney, etc.
        
           | mistrial9 wrote:
            | Just guessing that the tech leadership at all of those
            | traces back to Google somehow.
        
         | logicchains wrote:
         | Not at Google but currently using Jax to leverage TPUs, because
         | AWS's GPU pricing is eye-gougingly expensive. For the lower-end
          | A10 GPUs, the price-per-GPU for a 4 GPU machine is 1.5x the
          | price for a 1 GPU machine, and the price-per-GPU for an 8 GPU
          | machine is 2x the price of a 1 GPU machine! If you want an
          | A100 or H100, the only option is renting an 8 GPU instance.
          | With properly TPU-optimised code you get something like 30-50%
          | cost savings on GCP TPUs compared to AWS (and I say that as
          | someone who otherwise doesn't like Google as a company and
          | would prefer to avoid GCP if there weren't such a significant
          | cost advantage).
        
         | KeplerBoy wrote:
         | I use it for GPU accelerated signal processing. It really
         | delivers on the promise of "Numpy but for GPU" better than all
         | competing libraries out there.
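          | 
          | (For the unfamiliar, a small illustration of that "NumPy but
          | for GPU" promise -- ordinary NumPy-style array code,
          | jit-compiled for whatever accelerator is available:)
          | 
          |   import jax
          |   import jax.numpy as jnp
          | 
          |   @jax.jit
          |   def power_spectrum(x):
          |       # drop-in equivalents of np.fft.rfft and np.abs
          |       return jnp.abs(jnp.fft.rfft(x)) ** 2
          | 
          |   spec = power_spectrum(jnp.ones(1024))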
        
         | sudosysgen wrote:
          | I use it all the time, and there are also a few classes at my
          | uni that use Jax. It's really great for experimentation and
          | research; you can do a lot of things in Jax that you just
          | can't in, say, PyTorch.
        
         | MasterScrat wrote:
         | We've built our startup from scratch on JAX, selling text-to-
         | image model finetuning, and it's given us a consistent edge not
         | only in terms of pure performance but also in terms of "dollars
         | per unit of work"
        
         | eli_gottlieb wrote:
         | If JAX had affine_grid() and grid_sample(), I'd be using it
         | instead of PyTorch for my current project.
        
       | catgary wrote:
        | I've only been reading through the docs for a few moments, but
        | I'm pleasantly surprised to find that the authors are using
        | effect handlers to handle effectful computations in ML models. I
        | was in the process of translating a model from torch to Jax
        | using Equinox; this makes me think Penzai could be a better
        | choice.
        
         | patrickkidger wrote:
         | I was just reading this too! I think it's a really interesting
         | choice in the design space.
         | 
         | So to elucidate this a little bit, the trade-off is that this
         | is now incompatible with e.g. `jax.grad` or `lax.scan`: you
         | can't compose things in the order
         | `discharge_effect(jax.grad(your_model_here))`, or put an
         | effectful `lax.scan` inside your forward pass, etc. The effect-
         | discharging process only knows how to handle traversing pytree
         | structures. (And they do mention this at the end of their
         | docs.)
         | 
          | This kind of thing was actually something I explicitly
          | considered later on in Equinox, but decided against, in part
          | because I couldn't see a way to make it work either. The goal
          | of Equinox was always absolute compatibility with arbitrary
          | JAX code.
         | 
         | Now, none of that should be taken as a bash at Penzai! They've
         | made a different set of trade-offs, and if the above
         | incompatibility doesn't affect your goals then indeed their
         | effect system is incredibly elegant, so certainly give it a
         | try. (Seriously, it's been pretty cool to see the release of
         | Penzai, which explicitly acknowledges how much it's inspired by
         | Equinox.)
        
           | ddjohnson wrote:
           | Author of Penzai here! In idiomatic Penzai usage, you should
           | always discharge all effects before running your model. While
           | it's true you can't do
           | `discharge_effect(jax.grad(your_model_here))`, you can still
           | do `jax.grad(discharge_effect(your_model_here))`, which is
           | probably what you meant to do anyway in most cases. Once
           | you've wrapped your model in a handler layer, it has a pure
           | interface again, which makes it fully compatible with all
           | arbitrary JAX transformations. The intended use of effects is
           | as an internal helper to simplify plumbing of values into and
           | out of layers, not as something that affects the top-level
           | interface of using the model!
           | 
           | (As an example of this, the GemmaTransformer example model
           | uses the SideInput effect internally to do attention masking.
           | But it exposes a pure functional interface by using a handler
           | internally, so you can call it anywhere you could call an
           | Equinox model, and you shouldn't have to think about the
           | effect system at all as a user of the model.)
           | 
           | It's not clear to me what the semantics of ordinary JAX
           | transformations like `lax.scan` should be if the model has
           | side effects. But if you don't have any effects in your
           | model, or if you've explicitly handled them already, then
           | it's perfectly fine to use `lax.scan`. This is similar to how
           | it works in ordinary JAX; if you try to do a `lax.scan` over
           | a function that mutates Python state, you'll probably hit an
           | error or get something unexpected. But if you mutate Python
           | state internally inside `lax.scan`, it works fine.
           | 
           | I'll also note that adding support for higher-order layer
           | combinators (like "layer scan") is something that's on the
           | roadmap! The goal would be to support some of the fancier
           | features of libraries like Flax when you need them, while
           | still admitting a simple purely-functional mental model when
           | you don't.
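            | 
            | (To make the ordering concrete, a toy stand-in: an
            | "effectful" function plus a handler that discharges the
            | effect by closing over a side input. The names are
            | illustrative, not Penzai's actual API:)
            | 
            |   import jax
            |   import jax.numpy as jnp
            | 
            |   def effectful_model(params, x, side_input):
            |       # "effect": reads a value plumbed in from outside
            |       return params["w"] * x + side_input
            | 
            |   def discharge_effect(model, side_input):
            |       # handler: close over the side input, returning a
            |       # pure function of (params, x)
            |       return lambda params, x: model(params, x, side_input)
            | 
            |   pure_model = discharge_effect(effectful_model,
            |                                 jnp.ones(3))
            | 
            |   def loss(params, x):
            |       return jnp.sum(pure_model(params, x) ** 2)
            | 
            |   # grad applied *outside* the handler works, as described
            |   params = {"w": jnp.array(2.0)}
            |   grads = jax.grad(loss)(params, jnp.arange(3.0))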
        
         | ddjohnson wrote:
         | Thanks! This is one of the more experimental design choices I
         | made in designing Penzai, but so far I've found it to be quite
         | useful.
         | 
         | The effect system does come with a few sharp edges at the
         | moment if you want to use JAX transformations inside the
         | forward pass of your model (see my reply to Patrick), but I'm
         | hoping to make it more flexible as time goes on. (Figuring out
         | how effect systems should compose with function transformations
         | is a bit nontrivial!)
         | 
         | Please let me know if you run into any issues using Penzai for
         | your model! (Also, most of Penzai's visualization and patching
         | utilities should work with Equinox too, so you shouldn't
         | necessarily need to fully commit to either one.)
        
       | ubj wrote:
       | Does anyone know if and how well Penzai can work with Diffrax
       | [1]? I currently use Diffrax + Equinox for scientific machine
       | learning. Penzai looks like an interesting alternative to
       | Equinox.
       | 
       | [1]: https://docs.kidger.site/diffrax/
        
         | thatguysaguy wrote:
         | Not sure on the specific combination, but since everything in
         | Jax is functionally pure it's generally really easy to compose
         | libraries. E.g. I've written code which embedded a flax model
         | inside a haiku model without much effort.
        
         | patrickkidger wrote:
         | IIUC then penzai is (deliberately) sacrificing support for
         | higher-order operations like `lax.{while_loop, scan, cond}` or
         | `diffrax.diffeqsolve`, in return for some of the other new
         | features it is trying out (treescope, effects).
         | 
         | So it's slightly more framework-y than Equinox and will not be
          | completely compatible with arbitrary JAX code. However, I have
          | already had a collaborator demonstrate that as long as you
          | _don't_ use any higher-order operations, treescope will
          | actually work out-of-the-box with Equinox modules!
         | 
         | So I think the answer to your question is "sort of":
         | 
         | * As long as you only try to inspect things that are happening
         | outside of your `diffrax.diffeqsolve` then you should be good
         | to go. And moreover can probably do this simply by using e.g.
         | Penzai's treescope directly alongside your existing Equinox
         | code, without needing to move things over wholesale.
         | 
          | * But anything inside probably isn't supported and, if I
          | understand their setup correctly, can never be supported. (Not
          | bashing Penzai there, which I think genuinely looks excellent
          | -- I think it's just fundamentally tricky at a technical
          | level.)
        
           | ddjohnson wrote:
           | Author of Penzai here. I think the answer is a bit more
           | nuanced (and closer to "yes") than this:
           | 
           | - If you want to use the treescope pretty-printer or the
           | pz.select tree manipulation utility, those should work out-
           | of-the-box with both Equinox and Diffrax. Penzai's utilities
           | are designed to be as modular as possible (we explicitly try
           | not to be "frameworky") so they support arbitrary JAX
           | pytrees; if you run into any problems with this please file
           | an issue!
           | 
           | - If you want to call a Penzai model inside
           | `diffrax.diffeqsolve`, that should also be fully supported
           | out of the box. Penzai models expose a pure functional
           | interface when called, so you should be able to call a Penzai
           | model anywhere that you'd call an Equinox model. From the
           | perspective of the model user, you should be able to think of
           | the effect system as an implementation detail. Again, if you
           | run into problems here, please file an issue.
           | 
           | - If you want to write your own Penzai layer that uses
           | `diffrax.diffeqsolve` internally, that should also work. You
           | can put arbitrary logic inside a Penzai layer as long as it's
           | pure.
           | 
           | - The specific thing that is _not_ currently fully supported
           | is: (1) defining a higher-order Penzai combinator _layer_
           | that uses `diffrax.diffeqsolve` internally, (2) and having
           | that layer run one of its sublayers inside the
           | `diffrax.diffeqsolve` function, (3) while simultaneously
           | having that internal sublayer use an effect (like random
           | numbers, state, or parameter sharing), (4) where the handler
           | for that effect is placed _outside_ of the combinator layer.
           | This is because the temporary effect implementation node that
            | gets inserted while a handler is running isn't a JAX array
           | type, so you'll get a JAX error when you try to pass it
           | through a function transformation.
           | 
           | This last case is something I'd like to support as well, but
           | I still need to figure out what the semantics of it should
           | be. (E.g. what does it even mean to solve a differential
           | equation that has a local state variable in it?) I think
           | having side effects inside a transformed function is
           | fundamentally hard to get right!
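            | 
            | (A sketch of the second bullet's pattern, with a plain pure
            | function standing in for a discharged Penzai model; this
            | follows Diffrax's documented diffeqsolve API, so double-check
            | against the current docs:)
            | 
            |   import diffrax
            |   import jax.numpy as jnp
            | 
            |   def pure_model(t, y, args):
            |       # stand-in for any pure callable, e.g. a handled
            |       # Penzai or Equinox model used as the vector field
            |       return -args["decay"] * y
            | 
            |   sol = diffrax.diffeqsolve(
            |       diffrax.ODETerm(pure_model),
            |       diffrax.Tsit5(),
            |       t0=0.0, t1=1.0, dt0=0.1,
            |       y0=jnp.ones(2),
            |       args={"decay": jnp.array(0.5)},
            |   )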
        
       | pizza wrote:
        | I remember pytorch has some pytree capability, no? So is it
        | safe to say that any pytree-compatible modules here are already
        | compatible w/ pytorch?
        
         | ddjohnson wrote:
         | Author here! I didn't know PyTorch had its own pytree system.
         | It looks like it's separate from JAX's pytree registry, though,
         | so Penzai's tooling probably won't work with PyTorch out of the
         | box.
        
         | sillysaurusx wrote:
         | I implemented Jax's pytrees in pure python. You can use it with
         | whatever you want. https://github.com/shawwn/pytreez
         | 
          | The readme is a todo, but the tests are complete. They're the
          | same ones that Jax itself uses, but with zero dependencies.
          | https://github.com/shawwn/pytreez/blob/master/tests/test_pyt...
         | 
          | The concept is simple. The hard part is cross-pollination.
          | Suppose you wanted to literally use Jax pytrees with PyTorch.
          | Now you'll have to import Jax, or my library, and register your
          | modules with it. But anything else that ever uses pytrees needs
          | to use the same pytree library, because the registry (the thing
          | that keeps track of pytree-compatible classes) is in the
          | library you choose. They don't share registries.
         | 
         | A better way of phrasing it is that if you use a jax-style
         | pytree interface, it should work with any other pytree library.
         | But to my knowledge, the only pytree library besides Jax itself
         | is mine here, and only I use it. So when you ask if pytree-
         | compatible modules are compatible with PyTorch, it's equivalent
         | to asking whether PyTorch projects use jax, and the answer
         | tends to be no.
         | 
         | EDIT: perhaps I'm outdated. OP says that PyTorch has pytree
         | functionality now.
         | https://news.ycombinator.com/item?id=40109662 I guess yet again
         | I was ahead of the times by a couple years; happy to see other
         | ecosystems catch up. Hopefully seeing a simple implementation
         | will clarify the tradeoffs.
         | 
         | The best approach for a universal pytree library would be to
         | assume that any class with tree_flatten and tree_unflatten
          | methods is pytreeable, and not require those classes to be
         | explicitly registered. That way you don't have to worry whether
         | you're using Jax or PyTorch pytrees. But I gave up trying to
         | make library-agnostic ML modules; in practice it's better just
         | to choose Jax or PyTorch and be done with it, since making
         | PyTorch modules run in Jax automatically (and vice versa) is a
         | fool's errand (I was the fool, and it was an errand) for many
         | reasons, not the least of which is that Jax builds an explicit
         | computation graph via jax.jit, a feature PyTorch has only
         | recently (and reluctantly) embraced. But of course, that means
         | if you pick the wrong ecosystem, you'll miss out on the best
         | tools -- hello React vs Vue, or Unreal Engine vs Unity, or
         | dozens of other examples.
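          | 
          | (For the curious, this is roughly what registration looks like
          | on the JAX side -- the per-library registry the parent
          | comments describe:)
          | 
          |   import jax
          |   import jax.numpy as jnp
          | 
          |   class Linear:
          |       def __init__(self, w, b):
          |           self.w, self.b = w, b
          | 
          |   def _flatten(m):
          |       return (m.w, m.b), None   # (children, static aux data)
          | 
          |   def _unflatten(aux, children):
          |       return Linear(*children)
          | 
          |   jax.tree_util.register_pytree_node(
          |       Linear, _flatten, _unflatten)
          | 
          |   m = Linear(jnp.ones((2, 2)), jnp.zeros(2))
          |   # now any pytree utility can traverse the whole module
          |   doubled = jax.tree_util.tree_map(lambda a: 2 * a, m)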
        
       ___________________________________________________________________
       (page generated 2024-04-21 23:00 UTC)