[HN Gopher] MeZO: Fine-Tuning Language Models with Just Forward ...
       ___________________________________________________________________
        
       MeZO: Fine-Tuning Language Models with Just Forward Passes
        
       Author : behnamoh
       Score  : 80 points
       Date   : 2023-06-06 17:49 UTC (5 hours ago)
        
 (HTM) web link (github.com)
 (TXT) w3m dump (github.com)
        
       | ssivark wrote:
       | Some context, for those who might be getting carried away:
       | 
        | The essential idea here is estimating gradients using numerical
        | (forward-pass) function evaluation. It's called "zeroth order"
        | because it uses only function values (zeroth derivatives)
        | rather than first-order information -- e.g., the gradient
        | computed via backpropagation -- and it gives a much, much worse
        | approximation to the gradient. Here's the catch: bad gradient
        | estimates mean that the optimization takes much longer -- in
        | this paper, MeZO has been allowed 100x as many optimization
        | steps as the baseline fine-tuning method. That is why this
        | method isn't commonly used despite being one of the oldest /
        | simplest / most obvious.
       | 
       | That it seems to still be (empirically) good enough for fine
       | tuning is certainly interesting & encouraging. Further, 100x
       | optimization steps might even be a worthwhile trade off in
       | practice (comparing hardware availability, costs, etc)! But at
       | the moment it seems like a pretty niche improvement, unless this
       | phenomenon turns out to generalize for some principled reason.
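        | 
        | To make "zeroth order" concrete: in one dimension it is just a
        | finite difference. A toy sketch (my own illustration, not from
        | the paper):
        | 
        |     # Mine, not the paper's: first-order (analytic) vs
        |     # zeroth-order (two forward evals) derivative of x^2.
        |     def f(x):
        |         return x * x
        | 
        |     x, eps = 3.0, 1e-3
        |     first_order = 2 * x                              # exact: 6.0
        |     zeroth_order = (f(x + eps) - f(x - eps)) / (2 * eps)
        |     print(first_order, zeroth_order)                 # both ~6.0
        | 
        | In 1-D the two agree almost exactly; the trouble is that with
        | billions of parameters, one such probe only reveals the slope
        | along a single direction at a time.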
        
         | gwern wrote:
         | Other points to note: this is a lot like evolution strategies
         | or CMA-ES if you're familiar with those, so inherits their
         | strengths and weaknesses; it's sample/compute-inefficient
         | because it doesn't use real gradients, but that also means you
          | can use non-differentiable tools, so this might be useful in
          | an RL context or a tool/API context, which are increasingly
          | big deals for LLM use cases.
         | 
         | And it's a lot less sample/compute-inefficient than you might
         | expect from such low-quality optimization applied to such
         | highly-parameterized models; so this is another example of the
         | blessings of scale, and how very large neural net models are,
          | in some sense, simpler, more linear/interpretable, more
         | generalizable, and more efficient than they 'ought' to be.
        
         | oersted wrote:
         | Since you seem knowledgeable about the topic and have explained
         | it quite well, I'd like to ask a follow-up question.
         | 
         | Could you summarize how this "zeroth order" gradient
         | approximation works? In simple terms, how do we go from a
         | measurement of the forward-pass output error (loss), to knowing
         | in which direction to nudge each parameter to approach a
         | minimum?
         | 
         | Is it some kind of swarm optimization (not sure if it's the
         | right term), where the local neighborhood in parameter space is
         | sampled to approximate the gradient?
         | 
          | If that is the case, it's not just that more steps are
          | required due to bad gradient approximations -- multiple
          | passes are also needed to get each approximate gradient
          | (presumably many of them, given the high dimensionality).
        
           | ssivark wrote:
           | You evaluate the function at two neighboring points, and the
           | difference of function values (divided by the distance
           | between the two points) gives you an estimate of the gradient
           | projected along the direction between the two points. It's
           | literally the definition of the derivative -- but you cannot
           | practically take an infinitesimal step, so the points are a
           | finite distance away (substantially bigger than machine
           | precision).
           | 
           | Another way to think about it -- you choose a random line
           | through your current location, and test neighboring points in
           | two opposite directions, and step in the direction of
           | whichever seems better. That's why it costs as much as two
           | forward passes. If you squint a little, the random choice of
           | direction (line) makes it look kinda like evolutionary
           | search.
           | 
           | This is a textbook method, so I'm sure there must be some
           | description on the web with pretty pictures and clearly
           | worked out math -- but I'm AFK right now and can't find a
           | good one through my phone :-(
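            | 
            | Rough sketch in code of that two-point, random-direction
            | estimate (my own toy version, not the paper's
            | implementation; MeZO adds a memory trick on top of this):
            | 
            |     import numpy as np
            | 
            |     def zo_step(loss, theta, eps=1e-3, lr=1e-2):
            |         # Random direction z through the current point;
            |         # probe the loss on both sides along it.
            |         z = np.random.standard_normal(theta.shape)
            |         l_plus = loss(theta + eps * z)   # forward pass 1
            |         l_minus = loss(theta - eps * z)  # forward pass 2
            |         # The centered difference estimates the slope
            |         # along z; scaling z by it estimates the gradient.
            |         g_hat = (l_plus - l_minus) / (2 * eps) * z
            |         return theta - lr * g_hat
            | 
            |     # Toy usage: minimize a quadratic, forward passes only.
            |     loss = lambda th: float(np.sum(th ** 2))
            |     theta = np.ones(10)
            |     for _ in range(2000):
            |         theta = zo_step(loss, theta)
            |     print(loss(theta))  # ~0; note lr must shrink as the
            |                         # dimension grows, hence many steps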
        
             | nighthawk454 wrote:
             | Maybe check out this paper, the first page has a decent
             | diagram and caption!
             | 
             | https://arxiv.org/abs/2006.06224
        
         | bloaf wrote:
          | This method would work on recurrent neural nets, not just
          | feed-forward ones.
        
       | brucethemoose2 wrote:
       | A repost, but it should be reposted. This is amazing, maybe even
       | scary.
        
         | gwern wrote:
         | https://arxiv.org/abs/2305.17333
        
         | empalms wrote:
         | Could you point to the OP(s) please? No luck here with HN
         | search
        
           | icpmacdo wrote:
            | Maybe a reference to the GGML thread.
        
         | pm wrote:
          | As someone who's not familiar enough with LLMs to deduce why
          | this is amazing or scary, would you kindly explain why this
          | is so?
        
           | heyitsguay wrote:
           | It's not. That's part of the uninformed AI hype train that
           | will consistently be posting first on AI stuff for the next 6
           | months.
           | 
            | Right now the main bottleneck for LLM size is GPU memory
            | (VRAM). Training requires much more VRAM than inference,
            | which limits the ability of entities that aren't Google- or
            | OpenAI-scale to finetune models (aka do a little more
            | training on your custom dataset).
            | 
            | The paper here suggests that one can actually finetune LLMs
            | with inference-sized VRAM usage instead of training-sized
            | VRAM usage. If true, it will be possible to finetune larger
            | models on smaller (though still expensive) GPUs -- like a
            | single 3090 instead of 1x or 8x A100s. So, more people can
            | create more customized models.
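            | 
            | For the curious, the memory trick (paraphrasing the paper,
            | not their actual code) is that the random perturbation is
            | never stored: you keep only an RNG seed, perturb the
            | weights in place, and regenerate the same perturbation
            | whenever it's needed again, so peak memory stays at
            | inference level:
            | 
            |     import numpy as np
            | 
            |     def mezo_step(loss, theta, eps=1e-3, lr=1e-5):
            |         # Keep only a seed; z is regenerated from it on
            |         # demand. (The real implementation streams this
            |         # per parameter, avoiding even the temporary.)
            |         seed = np.random.randint(2**31)
            | 
            |         def z():
            |             rng = np.random.default_rng(seed)
            |             return rng.standard_normal(theta.shape)
            | 
            |         theta += eps * z()       # theta + eps*z
            |         l_plus = loss(theta)
            |         theta -= 2 * eps * z()   # theta - eps*z
            |         l_minus = loss(theta)
            |         theta += eps * z()       # restore original weights
            | 
            |         g = (l_plus - l_minus) / (2 * eps)  # one scalar
            |         theta -= lr * g * z()    # in-place SGD step
            |         return theta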
        
             | vvladymyrov wrote:
              | Inference can be done on CPU+RAM, but it is much slower
              | (like tens of seconds per token). So reducing the amount
              | of memory used by the model during training would reduce
              | the number of compute operations, which could potentially
              | make CPU+RAM suitable for fine-tuning within a reasonable
              | amount of time too. Basically, a 12x smaller GPU memory
              | requirement also translates to 12x fewer compute
              | operations (compute on CPU allows less parallelism than
              | on GPU).
              | 
              | The paper doesn't focus on CPU or GPU training time
              | improvements -- I'd assume there is no significant
              | improvement in the GPU case. For CPU it is logical to
              | expect a 12x training speed improvement, but that is
              | still too slow to be practically useful.
        
               | gliptic wrote:
                | > For CPU it is logical to expect a 12x training speed
                | improvement, but that is still too slow to be
                | practically useful.
               | 
               | I don't see what you base this on. MeZO trades one back-
               | propagation pass for another forward pass. Why would that
               | be 12x faster? It's also clear the convergence rate is
               | slower than plain SGD (never mind AdamW) by a factor
               | proportional to the effective rank of the Hessian.
        
           | brucethemoose2 wrote:
           | heyitsguay is correct, but in addition:
           | 
            | - The researchers didn't even explore quantization. In
            | theory, 4-bit quantization would allow for training on even
            | more modest hardware.
            | 
            | - Memory use aside, forward-pass-only training is
            | potentially a big speed increase.
            | 
            | - This method is very amenable to decentralized training
            | (see the sketch below).
            | 
            | It scares me because it feels like a gateway to networked,
            | self-training LLMs on commodity hardware. I thought this was
            | a long way away... now it doesn't feel so far away.
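            | 
            | To unpack the decentralized point: a zeroth-order update is
            | fully described by an RNG seed plus one scalar, so workers
            | could sync by exchanging two numbers per step instead of
            | full gradient tensors (the same property evolution
            | strategies exploit). Hypothetical sketch, names mine:
            | 
            |     import numpy as np
            | 
            |     def apply_update(theta, seed, g, lr=1e-5):
            |         # Any worker reconstructs the exact same update
            |         # from (seed, g) alone -- nothing big on the wire.
            |         rng = np.random.default_rng(seed)
            |         z = rng.standard_normal(theta.shape)
            |         theta -= lr * g * z
            |         return theta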
        
             | gliptic wrote:
              | > - Memory use aside, forward-pass-only training is
              | potentially a big speed increase.
             | 
             | Forward-pass only doesn't mean it's faster. It converges
             | much slower than even SGD. It is a memory-time tradeoff,
             | but that's it.
        
           | [deleted]
        
       ___________________________________________________________________
       (page generated 2023-06-06 23:00 UTC)