[HN Gopher] TokenFormer: Rethinking Transformer Scaling with Tok...
       ___________________________________________________________________
        
       TokenFormer: Rethinking Transformer Scaling with Tokenized Model
       Parameters
        
       Author : og_kalu
       Score  : 110 points
       Date   : 2024-11-01 14:10 UTC (8 hours ago)
        
 (HTM) web link (arxiv.org)
 (TXT) w3m dump (arxiv.org)
        
       | cs702 wrote:
        | The authors factorize every weight matrix with an attention
        | mechanism:
        | 
        |     weight = attention(token_query, weight_keys, weight_values)
       | 
       | In other words, they query weight_keys to fetch the
       | weight_values, and mix them to compute each weight on the spot.
       | 
       | Increasing model size becomes a matter of adding more weight_keys
       | and weight_values, and incrementally training them.
       | 
       | Simple, clever, and it seems to work well. Beautiful.
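        | 
        | A minimal sketch of that in PyTorch (my reading, not the
        | authors' code; I think the paper also tweaks the normalization,
        | so treat the plain softmax as illustrative):
        | 
        |     import torch
        |     import torch.nn.functional as F
        |     
        |     def pattention(x, weight_keys, weight_values):
        |         # x: (batch, d_in) input tokens acting as queries
        |         # keys: (n_param, d_in); values: (n_param, d_out)
        |         scores = x @ weight_keys.T          # (batch, n_param)
        |         attn = F.softmax(scores / x.shape[-1] ** 0.5, dim=-1)
        |         return attn @ weight_values         # replaces x @ W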
        
         | szcs wrote:
         | There is a particularly nice geometric interpretation of
         | attention I just realised recently in a flash of enlightenment,
         | best explained with an interactive Desmos plot (black dot is
         | draggable):
         | 
         | https://www.desmos.com/calculator/3rtqsyapxo
         | 
         | The above assumes the columns of K are normalised but bear with
         | me. K and V together form a vector database. V are the
         | payloads, each row containing a vector of data. K describes the
         | position of these points in space, on the surface of a
          | hypersphere. The query vector describes the query into the
         | database: the vector direction describes the point in space
         | that's being queried, the vector magnitude describes the radius
         | of the query. The result is the weighted average of vectors
         | from V, weighted by their distance from the query vector scaled
         | by the query radius (which has a smooth Gaussian falloff). A
          | recent paper from Nvidia that I recommend, which derives a
          | significant speedup by normalising vectors to a hypersphere:
         | https://arxiv.org/abs/2410.01131v1
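          | 
          | A small numeric check of that reading (assuming unit-norm
          | keys): the softmax attention weights are exactly a Gaussian
          | falloff in the distance between each key and the normalised
          | query, with the query magnitude setting the width:
          | 
          |     import torch
          |     import torch.nn.functional as F
          |     
          |     torch.manual_seed(0)
          |     K = F.normalize(torch.randn(5, 3), dim=-1)  # keys on sphere
          |     q = 4.0 * F.normalize(torch.randn(3), dim=-1)
          |     
          |     w1 = torch.softmax(K @ q, dim=-1)
          |     # same weights as a Gaussian of distance from q_hat,
          |     # with the width set by the magnitude r
          |     q_hat, r = q / q.norm(), q.norm()
          |     d2 = ((K - q_hat) ** 2).sum(-1)
          |     w2 = torch.softmax(-r * d2 / 2, dim=-1)
          |     print(torch.allclose(w1, w2))  # True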
        
           | liuliu wrote:
            | Yeah, I believe this intuition was first introduced by the
            | Neural Turing Machine line of work and later simplified in
            | the AIAYN paper (NTM maintains "external memory", a.k.a. the
            | weight_keys and weight_values here).
           | 
           | Disclaimer: these are from my memory, which can be wrong
           | entirely.
        
           | jdthedisciple wrote:
            | It looks fascinating, but I don't understand it. I haven't
            | yet gone deeply into the theory of attention networks.
           | 
           | Can you explain the desmos plot in simple terms?
        
         | anon291 wrote:
         | I believe there have been studies showing that the attention
         | mechanism allows estimation of gradients for one-shot learning
          | (i.e., based on what you tell the model you want in the input,
         | it will use attention to 'update' the weights of the linear
         | layers to 'learn' new information). This seems to be taking
         | that one step further and just using attention for the weight
         | estimations itself. The key insight here is that by adding more
         | tokens to the weight estimation calculation, you can get more
         | degrees of freedom.
         | 
         | Total aside, but imagining how many levels of functions are
         | present in the calculation of each activation here, and
         | thinking about how regular old differentiation and gradient
         | descent actually work to train these nested parameters, is
         | truly amazing, in my opinion.
        
           | cs702 wrote:
           | Yeah. This thing is "assembling a different transformer" on
           | the spot for each token.
           | 
           | If one thinks about it for more than a moment, it's kind of
           | incredible that it works.
        
             | 0-_-0 wrote:
              | I think the same about regular neural networks
        
       | davesque wrote:
       | Seems like a big deal. I feel like this could enable a new level
       | of modularity and compatibility between publicly available weight
       | sets, assuming they use similar channel dimensions. Maybe it also
       | provides a nice formalism for thinking about fine tuning, where
       | you could adopt certain heuristics for adding/removing key-value
       | pairs from the Pattention layers.
       | 
       | One interesting thing to note: sounds like model scaling happens
       | on the fly by adding key-value pairs as rows in the K and V
        | matrices of the Pattention layer. That suggests that weights
       | represented by tokens in the first rows may be more important
       | than weights in later rows. There may be a lot you could do with
       | that ordering of weights in terms of pruning and such.
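        | 
        | Growing a layer on the fly could look roughly like this (my
        | sketch; if I'm reading the paper right, the new pairs are
        | zero-initialised so the existing function is preserved):
        | 
        |     import torch
        |     
        |     def grow_pattention(weight_keys, weight_values, n_new):
        |         # append n_new parameter tokens as extra rows; the old,
        |         # earlier-trained rows are left untouched
        |         new_k = torch.zeros(n_new, weight_keys.shape[1])
        |         new_v = torch.zeros(n_new, weight_values.shape[1])
        |         return (torch.cat([weight_keys, new_k], dim=0),
        |                 torch.cat([weight_values, new_v], dim=0))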
        
         | valine wrote:
          | Unless I'm reading it wrong, I don't think rows matter.
          | Attention doesn't take sequence position into account
          | natively; that's why positional encodings exist.
        
           | davesque wrote:
           | I'm talking about the rows in the new K and V matrices
           | introduced by the paper, not rows in the input sequence. The
           | ordering of rows in the new K and V matrices does matter in
           | the sense that rows that appear further down were added later
           | in the training process to add new parameter tokens during
           | scaling. So those newer parameters _may_ represent knowledge
           | that is less fundamental and more about fine tuning on the
           | training set.
        
       | c0g wrote:
       | Surprised not to see a comparison to
       | https://paperswithcode.com/paper/augmenting-self-attention-w...
        
       | a_wild_dandan wrote:
       | This could be revolutionary. The PPL/compute graphs are damning.
       | If the Transformer is a function, then the TokenFormer feels like
       | a higher-order function. Perhaps this approach is a natural
       | direction for producing System Two reasoning? There's so much to
       | digest here...
        
       | valine wrote:
        | I would like to see a comparison of the inference-time compute
        | between a regular transformer and this. I'm assuming tokens/s is
       | lower since you need to compute the weights of the model for each
       | token prior to the actual attention calculations for the sequence
       | position.
        
         | logicchains wrote:
          | Isn't that figure 5 in the paper? It's for training, not
          | inference, but presumably if training is faster then inference
          | would be too, because they don't increase the dimension of the
          | text tokens when scaling up, which reduces the compute needed
          | for attention. That potentially limits how well the text-token
          | attention can keep track of things, though, because it's got
          | less space for passing things along.
        
       | goldenshale wrote:
       | This is a great idea. Being able to dynamically scale up model
       | sizes as datasets and use cases expand without needing to retrain
       | from scratch could enable a Cambrian explosion of interesting
       | stuff building on top of a Llama type model trained in this way.
        
       | logicchains wrote:
       | Seems this would naturally translate into a mixture of experts by
       | using a "hard" attention function so that only a fixed amount of
       | weight tokens get included in the calculation.
        
       | davesque wrote:
       | Seems like a lot of existing models could be converted to this
       | token parameter representation.
        
       | ml_thoughts wrote:
       | This seems closely related to the "Mixtral" approach of a
       | mixture-of-experts transformer [1]... I'm not claiming the
       | approach is not original, it just helped me understand what was
       | going on.
       | 
       | Consider a case of two "experts" or two "value parameter tokens."
       | 
       | The mixture of experts has a "router" network that provides a
       | weight to each expert (through a softmax) conditional on an
       | input. The output is a (sparse) weighted sum of the outputs of
       | the experts.
       | 
        | The TokenFormer has an "attention" layer that combines the token
        | with each key to provide a weight to each "value parameter"
        | token.
       | A(B+C) = AB + AC definitionally, so this is like applying a
       | weighted sum of distinct transformations.
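        | 
        | In code, the parallel I have in mind (hypothetical, just to
        | illustrate the algebra):
        | 
        |     import torch
        |     
        |     x = torch.randn(8)                  # one input token
        |     K = torch.randn(4, 8)               # 4 key parameter tokens
        |     V = torch.randn(4, 16)              # 4 value parameter tokens
        |     
        |     w = torch.softmax(x @ K.T, dim=-1)  # router-like weights
        |     out = w @ V                         # dense weighted sum
        |     # expert-by-expert, using A(B+C) = AB + AC:
        |     out2 = sum(w[j] * V[j] for j in range(4))
        |     print(torch.allclose(out, out2))    # True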
       | 
       | I think the differences are: a) where the non-linearity hits (the
       | above description doesn't consider an activation function), b)
       | this attention softmax is not (necessarily) sparse, c) that
       | "mixtral" networks only replace the feed-forward components of
       | the layer, and d) that extending a "mixtral" approach would
       | require re-training the "router" layers.
       | 
       | It seems like (d) is maybe the nicest feature here... my
       | intuition would think (a) doesn't matter much, (b) is debatable
       | (how close a sparse-MoE can approximate a dense-MoE), (c) has
       | probably been tried (guessing the ffwd limitation was just "more-
       | bang-for-buck-given-parameters" not an oversight)...
       | 
       | ... I wonder, though, if there might be diminishing returns here
       | (I believe that Mixture-of-Experts tends to struggle with
       | imbalanced "winner-take-all" dynamics, since "early" winners get
       | more gradient signal to improve their weights) and how different
        | this would have been from going from 3x7B to an 8x7B to a 24x7B
       | training approach (with a "retrain routing networks" step).
       | 
       | [1] https://arxiv.org/abs/2401.04088
        
       | mentalically wrote:
       | Eventually people will figure out how to nest neural networks in
       | the nodes and edges of an arbitrary graph.
        
       ___________________________________________________________________
       (page generated 2024-11-01 23:00 UTC)