[HN Gopher] Autodidax: JAX core from scratch
       ___________________________________________________________________
        
       Autodidax: JAX core from scratch
        
       Author : sva_
       Score  : 68 points
       Date   : 2023-02-09 22:12 UTC (2 days ago)
        
 (HTM) web link (jax.readthedocs.io)
 (TXT) w3m dump (jax.readthedocs.io)
        
       | civilized wrote:
        | This is great if you want to learn JAX, but it doesn't seem
        | like the most efficient way to learn how to implement symbolic
       | differentiation or automated generation of derivative functions
       | on a computer. Is there something that takes a more direct route,
       | maybe with Scheme or something? (I don't care about the language,
       | just whatever presents the least bureaucracy and overhead.)
        
         | 6gvONxR4sf7o wrote:
         | Automatic differentiation is a huge field going back to at
         | least the 1960s, so you're right that this is an idiosyncratic
         | way to do it. If you want to learn about it more generally,
         | "automatic differentiation" is the phrase to google. Tinygrad
         | and micrograd are particularly small libraries you might enjoy
         | looking at.
        
           | civilized wrote:
           | Thanks! It seems like the direct route should be something
           | like "turn the function into it's AST, apply derivative rules
           | to transform the AST, turn the result back into a function".
           | And this JAX post doesn't really speak in those terms, at
           | least not directly.
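            | 
            | (For what it's worth, that route fits in a page of
            | Python. A toy sketch, using nested tuples as a made-up
            | AST, nothing from the OP:)
            | 
            |   import math
            | 
            |   def diff(e, var):
            |       # numbers -> 0; variables -> 0 or 1
            |       if isinstance(e, (int, float)):
            |           return 0.0
            |       if isinstance(e, str):
            |           return 1.0 if e == var else 0.0
            |       op = e[0]
            |       if op == 'add':
            |           return ('add', diff(e[1], var),
            |                   diff(e[2], var))
            |       if op == 'mul':  # product rule
            |           return ('add',
            |                   ('mul', diff(e[1], var), e[2]),
            |                   ('mul', e[1], diff(e[2], var)))
            |       if op == 'sin':  # chain rule
            |           return ('mul', ('cos', e[1]),
            |                   diff(e[1], var))
            |       raise ValueError(op)
            | 
            |   def evaluate(e, env):
            |       if isinstance(e, (int, float)):
            |           return e
            |       if isinstance(e, str):
            |           return env[e]
            |       op, *args = e
            |       v = [evaluate(a, env) for a in args]
            |       if op == 'add': return v[0] + v[1]
            |       if op == 'mul': return v[0] * v[1]
            |       if op == 'sin': return math.sin(v[0])
            |       if op == 'cos': return math.cos(v[0])
            |       raise ValueError(op)
            | 
            |   # d/dx (x * sin x) = sin x + x cos x
            |   expr = ('mul', 'x', ('sin', 'x'))
            |   print(evaluate(diff(expr, 'x'), {'x': 2.0}))
            |   # sin(2) + 2*cos(2) ~= 0.077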
        
             | 6gvONxR4sf7o wrote:
             | > the direct route should be something like "turn the
              | function into its AST, apply derivative rules to transform
             | the AST, turn the result back into a function"
             | 
             | A couple points on that:
             | 
             | 1) The most direct route can be even simpler (what's called
             | forward mode differentiation)! You want the derivative of a
             | function in some direction at some point, and you can do
             | that by just passing in the point and the direction. If
             | every function involved knows how to transform the point
             | and transform the direction, then you just evaluate it step
              | by step, no transformations required. This is the "jvp"
              | approach in OP (a toy sketch follows at the end of this
              | comment).
             | 
             | 2) Something that is often misunderstood about JAX is that
             | JAX isn't just about taking derivatives. A large part of it
             | is just about transforming functions. Hence its
             | idiosyncrasies. It turns out one of the transformations you
             | can do is exactly what you said: transform it into its AST
             | (called jaxprs IIRC), then transform that into whatever you
             | want (gradients, inverses, parallel computations, JIT
             | compile it, whatever), then turn that back into a function.
             | And that's exactly how the linked post does reverse mode
             | differentiation a couple pages in (IIRC). That flexibility
             | is both what makes JAX's approach so interesting, and what
             | makes JAX such a PITA to debug.
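              | 
              | (Re point 1, a toy dual-number sketch of forward
              | mode, not JAX's actual machinery:)
              | 
              |   class Dual:
              |       # value at the point, tangent in the
              |       # direction
              |       def __init__(self, val, tan):
              |           self.val, self.tan = val, tan
              |       def __add__(self, o):
              |           return Dual(self.val + o.val,
              |                       self.tan + o.tan)
              |       def __mul__(self, o):  # product rule
              |           return Dual(self.val * o.val,
              |                       self.val * o.tan
              |                       + self.tan * o.val)
              | 
              |   def jvp(f, point, direction):
              |       return f(Dual(point, direction)).tan
              | 
              |   # d/dx (x*x + x) at x=3 is 2*3 + 1 = 7
              |   print(jvp(lambda x: x * x + x, 3.0, 1.0))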
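              | 
              | (And re point 2, the real JAX APIs on a toy
              | function: trace it once, then transform the trace:)
              | 
              |   import jax
              |   import jax.numpy as jnp
              | 
              |   def f(x):
              |       return jnp.sin(x) * x
              | 
              |   print(jax.make_jaxpr(f)(1.0))  # the IR
              |   print(jax.grad(f)(1.0))   # reverse mode
              |   print(jax.jit(f)(1.0))    # same trace, compiled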
        
       | 6gvONxR4sf7o wrote:
       | I love this. I wish every big library had something like this. It
       | helped me contribute to JAX in the past, and is a great
       | educational resource and source of inspiration for my own tools.
       | 
       | I've tried to find something similar for pytorch and numpy in the
       | past and was let down.
        
         | xavdid wrote:
         | Same! I looked into how pytest worked and got bogged down in
         | its plugin system (which is great, but very distracting when
         | trying to read the code). I've always been curious how the
         | assertion rewriting works.
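          | 
          | (The gist, as far as I understand it: pytest hooks module
          | import and rewrites each test file's AST so that a bare
          | "assert" reports the values it compared. A toy sketch of
          | that idea with the stdlib ast module, nothing like the
          | real rewriter:)
          | 
          |   import ast
          | 
          |   SRC = "assert add(2, 2) == 5"
          | 
          |   class Rewriter(ast.NodeTransformer):
          |       def visit_Assert(self, node):
          |           t = node.test
          |           if (node.msg is None
          |                   and isinstance(t, ast.Compare)
          |                   and len(t.ops) == 1):
          |               # attach f"{left!r} != {right!r}"
          |               # as the failure message
          |               node.msg = ast.JoinedStr(values=[
          |                   ast.FormattedValue(
          |                       value=t.left,
          |                       conversion=114,  # !r
          |                       format_spec=None),
          |                   ast.Constant(value=" != "),
          |                   ast.FormattedValue(
          |                       value=t.comparators[0],
          |                       conversion=114,
          |                       format_spec=None),
          |               ])
          |           return node
          | 
          |   tree = Rewriter().visit(ast.parse(SRC))
          |   ast.fix_missing_locations(tree)
          |   env = {"add": lambda a, b: a + b}
          |   try:
          |       exec(compile(tree, "<test>", "exec"), env)
          |   except AssertionError as e:
          |       print(e)  # 4 != 5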
        
         | sva_ wrote:
         | > I've tried to find something similar for pytorch and numpy in
         | the past and was let down.
         | 
          | Well, I've got a treat for you then
         | 
         | https://minitorch.github.io/
        
           | 6gvONxR4sf7o wrote:
           | I don't mean how torch's API _could_ be implemented, I mean
            | how pytorch implements it. Do you know which one this is?
        
         | rck wrote:
         | I think the closest thing for pytorch is the Karpathy video:
         | 
         | https://www.youtube.com/watch?v=VMj-3S1tku0
         | 
         | There's also an old book on the internals of numpy - not sure
         | how out of date it is though.
        
         | thelastbender12 wrote:
          | But seriously, for open source projects actually looking for
          | contributors (as opposed to companies with a developer
          | product they just happen to open source), there is no better
          | resource.
         | 
         | To understand a framework, you need a mental model of it. Good
         | documentation is helpful, but it seldom walks through why
         | specific design choices were necessary.
        
       ___________________________________________________________________
       (page generated 2023-02-11 23:00 UTC)