[HN Gopher] State-space models can learn in-context by gradient ...
___________________________________________________________________
State-space models can learn in-context by gradient descent
Author : dsalaj
Score : 60 points
Date : 2024-10-26 06:37 UTC (2 days ago)
(HTM) web link (arxiv.org)
(TXT) w3m dump (arxiv.org)
| dsalaj wrote:
| Deep state-space models (Deep SSMs) have shown capabilities for
| in-context learning on autoregressive tasks, similar to
| transformers. However, the architectural requirements and
| mechanisms enabling this in recurrent networks remain unclear.
| This study demonstrates that state-space model architectures can
| perform gradient-based learning and use it for in-context
| learning.
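|
| For intuition, "in-context learning as gradient descent on an
| implicit linear regression" can be read as the toy construction
| below (a sketch of the idea, not the paper's code): the context
| supplies (x, y) pairs, and the query is answered with weights
| obtained by a gradient step on the least-squares loss over those
| pairs.
|
|   import numpy as np
|
|   rng = np.random.default_rng(0)
|   w_true = rng.normal(size=3)
|   X_ctx = rng.normal(size=(8, 3))    # in-context inputs
|   y_ctx = X_ctx @ w_true             # in-context targets
|   x_query = rng.normal(size=3)
|
|   # Implicit linear model, updated by one gradient step on the
|   # mean-squared-error loss over the context pairs
|   w, lr = np.zeros(3), 0.1
|   grad = -2.0 * X_ctx.T @ (y_ctx - X_ctx @ w) / len(X_ctx)
|   w = w - lr * grad
|
|   print("prediction after one step:", x_query @ w)
|   print("true value:", x_query @ w_true)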
| billconan wrote:
| > We show that SSMs with local self-attention, a form of input-
| dependent input processing, can perform in-context learning
| analogously to transformers, i.e. through gradient descent steps
| on an implicit linear regression problem.
|
| I don't understand. The benefit of SSMs is better scalability
| than self-attention. Now this adds self-attention back?
| pitpatagain wrote:
| It adds a very local sliding-window attention; the context is
| only 3 adjacent frames per step. They need access to adjacent
| frames to show the implicit model gradient computation, but I
| haven't yet followed the derivation of why that is.
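|
| Concretely, "local self-attention over 3 adjacent frames" can be
| sketched as a causal attention in which each position attends
| only to itself and the two previous steps (illustrative, not the
| authors' implementation). Because the window is constant, the
| cost stays linear in sequence length, so the scaling advantage
| over full self-attention is preserved.
|
|   import numpy as np
|
|   def local_causal_attention(Q, K, V, window=3):
|       """Each position t attends to positions t-window+1 .. t."""
|       T, d = Q.shape
|       out = np.zeros_like(V)
|       for t in range(T):
|           lo = max(0, t - window + 1)
|           scores = Q[t] @ K[lo:t + 1].T / np.sqrt(d)
|           probs = np.exp(scores - scores.max())
|           probs /= probs.sum()
|           out[t] = probs @ V[lo:t + 1]
|       return out
|
|   x = np.random.default_rng(0).normal(size=(10, 4))
|   print(local_causal_attention(x, x, x).shape)   # (10, 4)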
| roger_ wrote:
| I'd love to see SSMs replace transformers but adapting them to
| non-causal, 2D+ inputs doesn't seem that straightforward.
|
| Is there a non-autoregressive future?
| quantadev wrote:
| > can reproduce the outputs of an implicit linear model with
| least squares loss after one step of gradient descent.
|
| Makes you wonder if we're training LLMs the hard way. For
| example, if computers had been invented before Calculus, we'd
| have been using "Numerical Integration" (iterating the
| differential squares to sum up areas, etc) and "Numerical
| Differentiation" (ditto for calculating slopes).
|
| So I wonder if we're simply in a pre-Calculus-like phase of
| NN/Perceptrons, where we haven't yet realized there's a
| mathematical way to "solve" a bunch of equations simultaneously
| and arrive at the best (or at least locally optimal) model
| weights for a given NN architecture and set of training data.
|
| From a theoretical standpoint it _IS_ a black-box problem like
| this, where the set of training data goes in and an array of
| model weights comes out. If I were to guess, I'd bet there'll be
| some kind of "random seed" we can add as input, and for each seed
| we'll get a different local minimum for the model weights.
|
| But I'm not a mathematician and there may be some sort of PROOF
| that what I just said can definitely never be done?
| kbr wrote:
| NNs have complex non-convex loss functions that don't admit a
| closed-form solution. Even for small models, training can be
| shown to be NP-complete. In fact, even for linear regression
| (least squares), which does have a closed-form solution, it can
| be computationally cheaper to run gradient descent, since the
| closed-form solution requires you to form and invert a large
| matrix (X^T X).
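|
| A small sketch of that trade-off (illustrative numbers only):
| the closed-form route forms and solves the d x d normal
| equations X^T X w = X^T y, while gradient descent only needs
| repeated matrix-vector products.
|
|   import numpy as np
|
|   rng = np.random.default_rng(1)
|   X = rng.normal(size=(1000, 50))
|   y = X @ rng.normal(size=50) + 0.01 * rng.normal(size=1000)
|
|   # Closed form: form X^T X (d x d) and solve the normal equations
|   w_closed = np.linalg.solve(X.T @ X, X.T @ y)
|
|   # Gradient descent: O(n d) work per step, no matrix factorization
|   w, lr = np.zeros(50), 0.1
|   for _ in range(200):
|       w -= lr * (2.0 / len(X)) * X.T @ (X @ w - y)
|
|   print(np.max(np.abs(w - w_closed)))   # both reach ~the same w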
| quantadev wrote:
| Thanks for that great clarification. I had seen all those
| words before, but just not in that particular order. haha.
|
| Maybe our only hope of doing LLM training runs in a tiny
| amount of time will be from Quantum Computing or even
| Photonic (wave-based) Computing.
| techbro92 wrote:
| There are actually neural networks with explicit optimization
| layers, but I don't think these have really had much success.
| derefr wrote:
| So, I'm just a layman when it comes to AI/ML, but I do understand
| computability -- what's possible to do with a given machine, and
| how we can build higher-computational-power primitives out of
| lower-computational-power primitives by plugging those primitives
| together with "glue" like parallel feed-forward chains (e.g. an
| ALU adder's carry bits) and loops over static sub-states of
| execution.
|
| My own mental model for what Transformers _must necessarily_ be
| doing, in order to be able to compute what they compute, given:
|
| 1. the primitives they're made of (for Transformers: matmul a
| learned matrix; vector-add a learned bias vector; normalize;
| softmax)
|
| 2. what those primitives can compute over a single layer
|
| 3. the low-ish total number of layers in a Transformer model
|
| ...is that they were already effectively "state space models" in
| practice. So this doesn't really surprise me!
|
| (To be explicit, my assertion is that the latent space between
| layers N and N+1 in a Transformer encodes a set of state
| variables [think CPU registers] used by the Nth serial
| computation step of an arbitrary set of learned algorithms.
| These algorithms are limited to those where every computation
| step can be encoded as a fused matmul-plus-vadd, so that the
| algorithm itself can be learned as a depthwise-extruded sequence
| of weights across the layers. The learned algorithms can and do
| share state variables, both as inputs and as outputs. And these
| state variables are all attenuated by an activation probability
| [in a Transformer: attention], such that the algorithms' outputs
| form a pre-multiplied _conditional probability_ of the output
| given the confidence of the inputs -- so the same state variable
| can be a low-confidence output for one algorithm and a high-
| confidence output for another, and the high-confidence component
| will swamp the low-confidence one.)
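|
| As a reference point for the primitive set being described, one
| layer "step" in this mental model is just an attention-weighted
| mix of per-token state vectors followed by a fused matmul-plus-
| bias and a normalization (an illustrative sketch of the
| primitives, not any particular model):
|
|   import numpy as np
|
|   def softmax(z):
|       z = z - z.max(axis=-1, keepdims=True)
|       e = np.exp(z)
|       return e / e.sum(axis=-1, keepdims=True)
|
|   def layer_step(H, W_q, W_k, W_v, W_out, b_out):
|       """One step on per-token state vectors H (T x d):
|       attention mixes states across positions, then a learned
|       matmul + bias vector updates each state."""
|       d = H.shape[1]
|       A = softmax((H @ W_q) @ (H @ W_k).T / np.sqrt(d))
|       mixed = A @ (H @ W_v)            # activation-weighted mix
|       H_new = mixed @ W_out + b_out    # fused matmul + vadd
|       return H_new / np.linalg.norm(H_new, axis=-1, keepdims=True)
|
|   rng = np.random.default_rng(0)
|   Ws = [rng.normal(size=(8, 8)) / np.sqrt(8) for _ in range(4)]
|   H = rng.normal(size=(6, 8))
|   print(layer_step(H, *Ws, np.zeros(8)).shape)   # (6, 8)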
| knowaveragejoe wrote:
| Your intuition is, I think, pretty close to accurate. See this
| paper from earlier this year:
|
| > While Transformers have been the main architecture behind
| deep learning's success in language modeling, state-space
| models (SSMs) such as Mamba have recently been shown to match
| or outperform Transformers at small to medium scale. We show
| that these families of models are actually quite closely
| related, and develop a rich framework of theoretical
| connections between SSMs and variants of attention, connected
| through various decompositions of a well-studied class of
| structured semiseparable matrices. Our state space duality
| (SSD) framework allows us to design a new architecture
| (Mamba-2) whose core layer is a refinement of Mamba's
| selective SSM that is 2-8X faster, while continuing to be
| competitive with Transformers on language modeling.
|
| https://arxiv.org/abs/2405.21060
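|
| The "duality" there has a compact toy form: a diagonal linear
| recurrence and a masked, attention-like matrix multiply compute
| exactly the same outputs (a scalar-state sketch of the idea, not
| the Mamba-2 code).
|
|   import numpy as np
|
|   rng = np.random.default_rng(0)
|   T = 6
|   a = rng.uniform(0.5, 1.0, size=T)   # input-dependent decays
|   b = rng.normal(size=T)              # input projections
|   x = rng.normal(size=T)
|
|   # SSM view: run the recurrence h_t = a_t h_{t-1} + b_t x_t
|   h, y_rec = 0.0, np.zeros(T)
|   for t in range(T):
|       h = a[t] * h + b[t] * x[t]
|       y_rec[t] = h
|
|   # "Attention" view: materialize the lower-triangular
|   # semiseparable matrix L[t, s] = a_{s+1}...a_t * b_s and
|   # multiply once.
|   L = np.zeros((T, T))
|   for t in range(T):
|       for s in range(t + 1):
|           L[t, s] = np.prod(a[s + 1:t + 1]) * b[s]
|   print(np.allclose(y_rec, L @ x))    # True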
| eli_gottlieb wrote:
| >Our key insight is that the diagonal linear recurrent layer can
| act as a gradient accumulator
|
| So they're sort of reinventing the discrete-time differentiator
| from signal processing, but parameterized neurally?
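|
| The quoted claim in its simplest form: with the diagonal
| recurrence set to the identity, the layer is a running sum, so
| it can add up per-token gradient contributions across the
| sequence (a toy reading of that sentence, not the paper's
| construction).
|
|   import numpy as np
|
|   rng = np.random.default_rng(0)
|   grads = rng.normal(size=(5, 3))   # per-token gradient terms
|
|   # Diagonal linear recurrence h_t = A h_{t-1} + u_t with A = I
|   A = np.ones(3)                    # diagonal of the recurrence
|   h = np.zeros(3)
|   for u in grads:
|       h = A * h + u
|
|   print(np.allclose(h, grads.sum(axis=0)))   # True: accumulated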
___________________________________________________________________
(page generated 2024-10-28 23:01 UTC)