[HN Gopher] MeZO: Fine-Tuning Language Models with Just Forward ...
___________________________________________________________________
MeZO: Fine-Tuning Language Models with Just Forward Passes
Author : behnamoh
Score : 80 points
Date : 2023-06-06 17:49 UTC (5 hours ago)
(HTM) web link (github.com)
(TXT) w3m dump (github.com)
| ssivark wrote:
| Some context, for those who might be getting carried away:
|
| The essential idea here is estimating gradients using numerical
| (forward-pass) function evaluation. The reason it's called
| "zeroth order" is that it's a much, much worse approximation to
| the gradient than first-order methods (e.g. computing the
| gradient via backpropagation). Here's the catch: bad gradient
| estimates mean that the optimization takes much longer -- in
| this paper, MeZO is allowed 100x as many optimization steps as
| the baseline fine-tuning method. That is why this method isn't
| commonly used, despite being one of the oldest / simplest /
| most obvious.
|
| That it seems to still be (empirically) good enough for fine
| tuning is certainly interesting & encouraging. Further, 100x
| optimization steps might even be a worthwhile trade-off in
| practice (comparing hardware availability, costs, etc)! But at
| the moment it seems like a pretty niche improvement, unless this
| phenomenon turns out to generalize for some principled reason.
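|
| As a toy illustration of "estimating the gradient with forward
| passes" -- the textbook finite-difference idea, not the
| paper's exact recipe:
|
|     # Estimate f'(3) for f(x) = x**2 using only two function
|     # evaluations (no analytic gradient).
|     def f(x):
|         return x ** 2
|
|     eps = 1e-3
|     x = 3.0
|     grad_est = (f(x + eps) - f(x - eps)) / (2 * eps)
|     print(grad_est)   # ~6.0, matching the true derivative 2*x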
| gwern wrote:
| Other points to note: this is a lot like evolution strategies
| or CMA-ES, if you're familiar with those, so it inherits their
| strengths and weaknesses. It's sample/compute-inefficient
| because it doesn't use real gradients, but that also means you
| can use non-differentiable tools, so this might be useful in
| an RL context or a tool/API context, which are increasingly
| big deals for LLM use cases.
|
| And it's a lot less sample/compute-inefficient than you might
| expect from such low-quality optimization applied to such
| highly-parameterized models; so this is another example of the
| blessings of scale, and how very large neural net models are,
| in some sense, simpler, more linear/interpretable, more
| generalizable, and more efficient than they 'ought' to be.
| oersted wrote:
| Since you seem knowledgeable about the topic and have explained
| it quite well, I'd like to ask a follow-up question.
|
| Could you summarize how this "zeroth order" gradient
| approximation works? In simple terms, how do we go from a
| measurement of the forward-pass output error (loss), to knowing
| in which direction to nudge each parameter to approach a
| minimum?
|
| Is it some kind of swarm optimization (not sure if it's the
| right term), where the local neighborhood in parameter space is
| sampled to approximate the gradient?
|
| If that is the case, it's not just that more steps are
| required due to bad gradient approximations; multiple forward
| passes are also needed to get each approximate gradient
| (presumably many of them, due to the high dimensionality).
| ssivark wrote:
| You evaluate the function at two neighboring points, and the
| difference of function values (divided by the distance
| between the two points) gives you an estimate of the gradient
| projected along the direction between the two points. It's
| literally the definition of the derivative -- but you cannot
| practically take an infinitesimal step, so the points are a
| finite distance away (substantially bigger than machine
| precision).
|
| Another way to think about it -- you choose a random line
| through your current location, and test neighboring points in
| two opposite directions, and step in the direction of
| whichever seems better. That's why it costs as much as two
| forward passes. If you squint a little, the random choice of
| direction (line) makes it look kinda like evolutionary
| search.
|
| This is a textbook method, so I'm sure there must be some
| description on the web with pretty pictures and clearly
| worked out math -- but I'm AFK right now and can't find a
| good one through my phone :-(
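|
| But roughly, the idea looks something like this in PyTorch
| (my own sketch, not the paper's exact implementation; loss_fn
| stands in for one forward pass over a batch and returns the
| scalar loss):
|
|     import torch
|
|     # One zeroth-order step: perturb the parameters along a
|     # single random direction z, compare the losses at
|     # theta + eps*z and theta - eps*z, and step downhill.
|     def zo_step(params, loss_fn, eps=1e-3, lr=1e-6):
|         z = [torch.randn_like(p) for p in params]
|
|         for p, zi in zip(params, z):
|             p.data.add_(eps * zi)
|         loss_plus = float(loss_fn())    # forward pass #1
|
|         for p, zi in zip(params, z):
|             p.data.sub_(2 * eps * zi)
|         loss_minus = float(loss_fn())   # forward pass #2
|
|         for p, zi in zip(params, z):
|             p.data.add_(eps * zi)       # restore parameters
|
|         # Finite difference = gradient projected onto z.
|         proj_grad = (loss_plus - loss_minus) / (2 * eps)
|
|         # Move against the projected gradient, along z.
|         for p, zi in zip(params, z):
|             p.data.sub_(lr * proj_grad * zi)
|
| A single random direction gives a very noisy estimate, which
| is part of why it needs so many more steps than backprop.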
| nighthawk454 wrote:
| Maybe check out this paper, the first page has a decent
| diagram and caption!
|
| https://arxiv.org/abs/2006.06224
| bloaf wrote:
| These methods would work on recurrent neural nets, not just
| feed-forward ones.
| brucethemoose2 wrote:
| A repost, but it should be reposted. This is amazing, maybe even
| scary.
| gwern wrote:
| https://arxiv.org/abs/2305.17333
| empalms wrote:
| Could you point to the OP(s) please? No luck here with HN
| search
| icpmacdo wrote:
| Maybe a reference to the GGML thread.
| pm wrote:
| As someone who's not familiar enough with LLMs to deduce why
| this is amazing or scary, would you kindly explain why this
| is so?
| heyitsguay wrote:
| It's not. That's part of the uninformed AI hype train that
| will consistently be posting first on AI stuff for the next 6
| months.
|
| Right now the main bottleneck for LLM size is GPU memory
| (VRAM). Training requires much more VRAM than inference,
| which limits the ability of entities that aren't at Google or
| OpenAI scale to fine-tune models (i.e. do a little more
| training on your custom dataset).
|
| The paper here suggests that one can actually fine-tune LLMs
| with inference-sized VRAM usage instead of training-sized
| VRAM usage. If true, it will be possible to fine-tune larger
| models on smaller (though still expensive) GPUs -- like a
| single 3090 instead of one or eight A100s. So, more people
| can create more customized models.
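|
| As a rough back-of-the-envelope illustration (my own numbers,
| not from the paper; activations and framework overhead are
| ignored), for a 7B-parameter model in fp16:
|
|     # Rough VRAM estimate: Adam fine-tuning vs. forward
|     # passes only. Assumes fp16 weights/gradients and fp32
|     # Adam state, the usual mixed-precision setup.
|     params = 7e9                 # a 7B-parameter model
|
|     weights = 2 * params         # fp16 weights: 2 bytes/param
|     grads = 2 * params           # fp16 gradients
|     adam_state = 8 * params      # fp32 momentum + variance
|     master_weights = 4 * params  # fp32 master copy
|
|     train_bytes = weights + grads + adam_state + master_weights
|     infer_bytes = weights        # forward pass: just weights
|
|     print(f"Adam fine-tuning : ~{train_bytes / 1e9:.0f} GB")
|     print(f"Forward-pass only: ~{infer_bytes / 1e9:.0f} GB")
|     # ~112 GB vs. ~14 GB with these assumptions
|
| The exact ratio depends on the optimizer and on activation
| memory, but that's the rough shape of the gap.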
| vvladymyrov wrote:
| Inference can be done on CPU+RAM, but it is much slower
| (like tens of seconds per token). So reducing the amount of
| memory the model uses during training would reduce the
| number of compute operations, and could potentially make
| CPU+RAM more suitable for fine-tuning within a reasonable
| amount of time too. Basically, a 12x smaller GPU memory
| requirement would also translate to 12x fewer compute
| operations (compute on CPU allows less parallelism than on
| GPU).
|
| The paper doesn't focus on CPU or GPU training-time
| improvements -- I'd assume there is no significant
| improvement in the GPU training case. For CPU it is logical
| to expect a 12x training speed improvement, but that is
| still too slow to be practically useful.
| gliptic wrote:
| > For CPU it is logical to expect a 12x training speed
| improvement, but that is still too slow to be practically
| useful.
|
| I don't see what you base this on. MeZO trades one back-
| propagation pass for another forward pass. Why would that
| be 12x faster? It's also clear the convergence rate is
| slower than plain SGD (never mind AdamW) by a factor
| proportional to the effective rank of the Hessian.
| brucethemoose2 wrote:
| heyitsguay is correct, but in addition:
|
| - The researchers didn't even explore quantization. In theory,
| 4-bit quantization would allow for training on even more
| modest hardware.
|
| - Memory use aside, forward-pass-only training is potentially
| a big speed increase.
|
| - This method is very amenable to decentralized training.
|
| It scares me because it feels like a gateway to networked,
| self-training LLMs on commodity hardware. I thought this was
| a long way away... now it doesn't feel so far away.
| gliptic wrote:
| > - Memory use aside, forward-pass-only training is
| potentially a big speed increase.
|
| Forward-pass-only doesn't mean it's faster. It converges
| much more slowly than even SGD. It's a memory-for-time
| tradeoff, but that's it.
| [deleted]
___________________________________________________________________
(page generated 2023-06-06 23:00 UTC)