[HN Gopher] TokenFormer: Rethinking Transformer Scaling with Tok...
___________________________________________________________________
TokenFormer: Rethinking Transformer Scaling with Tokenized Model
Parameters
Author : og_kalu
Score : 110 points
Date : 2024-11-01 14:10 UTC (8 hours ago)
(HTM) web link (arxiv.org)
(TXT) w3m dump (arxiv.org)
| cs702 wrote:
| The authors factorize every weight matrix with an attention
| mechanism: weight = attention(token_query,
| weight_keys, weight_values).
|
| In other words, they query weight_keys to fetch the
| weight_values, and mix them to compute each weight on the spot.
|
| Increasing model size becomes a matter of adding more weight_keys
| and weight_values, and incrementally training them.
|
| Simple, clever, and it seems to work well. Beautiful.
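|
| A minimal sketch of that factorization (my own illustration; it
| uses a plain softmax, whereas I believe the paper modifies the
| normalization of the scores):
|
|     import numpy as np
|
|     def pattention(x, weight_keys, weight_values):
|         # Each output is an attention-weighted mix of learnable
|         # value tokens; adding rows to weight_keys/weight_values
|         # grows the layer without changing d_in or d_out.
|         scores = x @ weight_keys.T          # (tokens, n_param)
|         w = np.exp(scores - scores.max(axis=-1, keepdims=True))
|         w /= w.sum(axis=-1, keepdims=True)
|         return w @ weight_values            # (tokens, d_out)
|
|     d_in, d_out, n_param = 8, 16, 32
|     rng = np.random.default_rng(0)
|     weight_keys = rng.normal(size=(n_param, d_in))
|     weight_values = rng.normal(size=(n_param, d_out))
|     x = rng.normal(size=(4, d_in))          # 4 input tokens
|     y = pattention(x, weight_keys, weight_values)
|     print(y.shape)                          # (4, 16)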
| szcs wrote:
| There is a particularly nice geometric interpretation of
| attention I just realised recently in a flash of enlightenment,
| best explained with an interactive Desmos plot (black dot is
| draggable):
|
| https://www.desmos.com/calculator/3rtqsyapxo
|
| The above assumes the columns of K are normalised, but bear
| with me. K and V together form a vector database. V holds the
| payloads, each row containing a vector of data. K describes the
| position of these points in space, on the surface of a
| hypersphere. The query vector describes the query into the
| database: its direction picks the point in space being queried,
| and its magnitude sets the radius of the query. The result is
| the weighted average of the vectors in V, weighted by their
| distance from the query point scaled by the query radius (with
| a smooth Gaussian falloff). I also recommend a recent paper
| from Nvidia, which obtains a significant speedup by normalising
| vectors to a hypersphere:
| https://arxiv.org/abs/2410.01131v1
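|
| The Gaussian falloff is easy to check numerically: with
| unit-norm keys, the softmax of the dot products is exactly a
| normalized Gaussian in the Euclidean distance to the query (a
| standalone sketch, not tied to the paper):
|
|     import numpy as np
|
|     rng = np.random.default_rng(0)
|     K = rng.normal(size=(5, 3))
|     K /= np.linalg.norm(K, axis=1, keepdims=True)  # unit keys
|     q = rng.normal(size=3)      # query; |q| sets the "radius"
|
|     dot_w = np.exp(K @ q)
|     dot_w /= dot_w.sum()
|
|     dist_w = np.exp(-0.5 * np.linalg.norm(K - q, axis=1) ** 2)
|     dist_w /= dist_w.sum()
|
|     print(np.allclose(dot_w, dist_w))   # True: same weights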
| liuliu wrote:
| Yeah, I believe this intuition was first introduced by the
| Neural Turing Machine line of work and later simplified in the
| AIAYN paper (the NTM maintains "external memory", a.k.a. the
| weight_keys and weight_values here).
|
| Disclaimer: these are from my memory, which can be wrong
| entirely.
| jdthedisciple wrote:
| It looks fascinating, but I don't understand it. I haven't yet
| gone deeply into the theory of attention networks.
|
| Can you explain the desmos plot in simple terms?
| anon291 wrote:
| I believe there have been studies showing that the attention
| mechanism allows estimation of gradients for one-shot learning
| (i.e., based on what you tell the model you want in the input,
| it will use attention to 'update' the weights of the linear
| layers to 'learn' new information). This seems to be taking
| that one step further and just using attention for the weight
| estimation itself. The key insight here is that by adding more
| tokens to the weight estimation calculation, you can get more
| degrees of freedom.
|
| Total aside, but imagining how many levels of functions are
| present in the calculation of each activation here, and
| thinking about how regular old differentiation and gradient
| descent actually work to train these nested parameters, is
| truly amazing, in my opinion.
| cs702 wrote:
| Yeah. This thing is "assembling a different transformer" on
| the spot for each token.
|
| If one thinks about it for more than a moment, it's kind of
| incredible that it works.
| 0-_-0 wrote:
| I think the same about regular neural networks
| davesque wrote:
| Seems like a big deal. I feel like this could enable a new level
| of modularity and compatibility between publicly available weight
| sets, assuming they use similar channel dimensions. Maybe it also
| provides a nice formalism for thinking about fine tuning, where
| you could adopt certain heuristics for adding/removing key-value
| pairs from the Pattention layers.
|
| One interesting thing to note: it sounds like model scaling
| happens on the fly by adding key-value pairs as rows in the K
| and V matrices of the Pattention layer. That suggests that the
| weights represented by tokens in the first rows may be more
| important than weights in later rows. There may be a lot you
| could do with that ordering of weights in terms of pruning and
| such.
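|
| A hedged sketch of that growth step (zero-initializing the new
| rows is my assumption about how freshly added pairs start out;
| the paper spells out its exact scheme):
|
|     import numpy as np
|
|     def scale_up(weight_keys, weight_values, n_new):
|         # Append n_new parameter tokens as extra rows of K and
|         # V. Zero rows start out contributing nothing to the
|         # value mix and are then trained incrementally.
|         new_k = np.zeros((n_new, weight_keys.shape[1]))
|         new_v = np.zeros((n_new, weight_values.shape[1]))
|         return (np.vstack([weight_keys, new_k]),
|                 np.vstack([weight_values, new_v]))
|
|     K, V = np.ones((32, 8)), np.ones((32, 16))
|     K2, V2 = scale_up(K, V, n_new=32)    # 32 -> 64 param tokens
|     print(K2.shape, V2.shape)            # (64, 8) (64, 16)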
| valine wrote:
| Unless I'm reading it wrong, I don't think rows matter.
| Attention doesn't take sequence position into account natively;
| that's why positional encodings exist.
| davesque wrote:
| I'm talking about the rows in the new K and V matrices
| introduced by the paper, not rows in the input sequence. The
| ordering of rows in the new K and V matrices does matter in
| the sense that rows that appear further down were added later
| in the training process to add new parameter tokens during
| scaling. So those newer parameters _may_ represent knowledge
| that is less fundamental and more about fine tuning on the
| training set.
| c0g wrote:
| Surprised not to see a comparison to
| https://paperswithcode.com/paper/augmenting-self-attention-w...
| a_wild_dandan wrote:
| This could be revolutionary. The PPL/compute graphs are damning.
| If the Transformer is a function, then the TokenFormer feels like
| a higher-order function. Perhaps this approach is a natural
| direction for producing System Two reasoning? There's so much to
| digest here...
| valine wrote:
| I would like to see a comparison for the inference time compute
| between a regular transformer and this. I'm assuming token/s is
| lower since you need to compute the weights of the model for each
| token prior to the actual attention calculations for the sequence
| position.
| logicchains wrote:
| Isn't that figure 5 in the paper? It's for training, not
| inference, but presumably if training is faster then inference
| would be too, because they don't increase the dimension of the
| text tokens when scaling up, which reduces the compute needed
| for attention. That potentially limits how well the text-token
| attention can keep track of things, though, because it has less
| space for passing information along.
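|
| Back-of-the-envelope numbers (my rough estimate, not from the
| paper) for why a fixed token dimension keeps the attention cost
| flat while scaling only grows the Pattention cost, which is
| linear in sequence length:
|
|     # Rough per-layer FLOP counts, ~2 FLOPs per multiply-add.
|     # Prints: n_param, attention GFLOPs, pattention GFLOPs.
|     T, d = 2048, 1024        # sequence length, channel dim
|     for n_param in (1024, 4096, 16384):
|         attn = 4 * T * T * d           # QK^T plus attn @ V
|         patt = 4 * T * n_param * d     # scores plus value mix
|         print(n_param, round(attn / 1e9), round(patt / 1e9))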
| goldenshale wrote:
| This is a great idea. Being able to dynamically scale up model
| sizes as datasets and use cases expand without needing to retrain
| from scratch could enable a Cambrian explosion of interesting
| stuff building on top of a Llama type model trained in this way.
| logicchains wrote:
| Seems like this would naturally translate into a mixture of
| experts by using a "hard" attention function, so that only a
| fixed number of weight tokens is included in the calculation.
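|
| A sketch of one way that hard variant could look (my own
| construction, not from the paper): keep only the top-k scoring
| parameter tokens per input token and renormalize.
|
|     import numpy as np
|
|     def topk_pattention(x, weight_keys, weight_values, k=4):
|         scores = x @ weight_keys.T        # (tokens, n_param)
|         kth = np.partition(scores, -k, axis=-1)[:, [-k]]
|         masked = np.where(scores >= kth, scores, -np.inf)
|         w = np.exp(masked - masked.max(axis=-1, keepdims=True))
|         w /= w.sum(axis=-1, keepdims=True)  # k nonzero weights
|         return w @ weight_values
|
|     rng = np.random.default_rng(0)
|     K, V = rng.normal(size=(64, 8)), rng.normal(size=(64, 16))
|     x = rng.normal(size=(4, 8))
|     print(topk_pattention(x, K, V, k=4).shape)   # (4, 16)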
| davesque wrote:
| Seems like a lot of existing models could be converted to this
| token parameter representation.
| ml_thoughts wrote:
| This seems closely related to the "Mixtral" approach of a
| mixture-of-experts transformer [1]... I'm not claiming the
| approach is not original, it just helped me understand what was
| going on.
|
| Consider a case of two "experts" or two "value parameter tokens."
|
| The mixture of experts has a "router" network that provides a
| weight to each expert (through a softmax) conditional on an
| input. The output is a (sparse) weighted sum of the outputs of
| the experts.
|
| The TokenFormer has an "attention" layer that combines the token
| with a key to provide a weight for each "value parameter" token.
| A(B+C) = AB + AC definitionally, so this is like applying a
| weighted sum of distinct transformations.
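|
| A quick numerical check of that distributive step, with two
| small linear "experts" standing in for two value parameter
| tokens (purely illustrative):
|
|     import numpy as np
|
|     rng = np.random.default_rng(0)
|     x = rng.normal(size=8)
|     W1 = rng.normal(size=(16, 8))   # expert / value token 1
|     W2 = rng.normal(size=(16, 8))   # expert / value token 2
|     w = np.array([0.7, 0.3])        # router or attention weights
|
|     moe_style = w[0] * (W1 @ x) + w[1] * (W2 @ x)
|     mixed_first = (w[0] * W1 + w[1] * W2) @ x
|     print(np.allclose(moe_style, mixed_first))   # True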
|
| I think the differences are: a) where the non-linearity hits (the
| above description doesn't consider an activation function), b)
| this attention softmax is not (necessarily) sparse, c) that
| "mixtral" networks only replace the feed-forward components of
| the layer, and d) that extending a "mixtral" approach would
| require re-training the "router" layers.
|
| It seems like (d) is maybe the nicest feature here... my
| intuition is that (a) doesn't matter much, (b) is debatable
| (how closely a sparse MoE can approximate a dense MoE), and (c)
| has probably been tried (guessing the ffwd-only limitation was
| just "more-bang-for-buck-given-parameters", not an oversight)...
|
| ... I wonder, though, if there might be diminishing returns here
| (I believe that Mixture-of-Experts tends to struggle with
| imbalanced "winner-take-all" dynamics, since "early" winners get
| more gradient signal to improve their weights) and how different
| this would have been from going from a 3x7B to an 8x7B to a
| 24x7B training approach (with a "retrain routing networks" step).
|
| [1] https://arxiv.org/abs/2401.04088
| mentalically wrote:
| Eventually people will figure out how to nest neural networks in
| the nodes and edges of an arbitrary graph.
___________________________________________________________________
(page generated 2024-11-01 23:00 UTC)