A few thoughts attempting to frame some related terminology and concepts around the limits of learning, optimization and generalization in my own head, mostly prompted by this tweet.
Much of machine learning at its core can be boiled down to the problem of attempting to approximate a function.
Practically all supervised learning fits this framing, where in a supervised problem you have examples of input data x and the corresponding outputs y, and you want to find a function f that maps one to the other.
Self-supervised learning fits here too, and LLMs are a good example of such a case: we are trying to determine a function that takes a sequence of tokens as input and predicts the next token. While we sometimes talk about just feeding the model “data”, we are really creating a structured supervised learning problem out of that data.
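To make that concrete, here is a rough sketch (my own toy illustration - the token IDs and context length are made up) of how a raw token sequence becomes a set of supervised input/target pairs for next-token prediction:

```python
# Toy illustration: turning a raw token sequence into (input, target) pairs
# for next-token prediction. The token IDs and context length are made up.
tokens = [12, 7, 99, 3, 42, 17, 5]   # one "document", already tokenized into IDs
context_length = 3

examples = []
for i in range(len(tokens) - context_length):
    context = tokens[i : i + context_length]   # input: a window of tokens
    target = tokens[i + context_length]        # label: the token that follows it
    examples.append((context, target))

for context, target in examples:
    print(f"input={context} -> predict {target}")
```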
We do not know what this function f actually is - we only have examples of its inputs and outputs.
In deep learning, we do this by defining a separate parameterized function f̂ - the neural network - that we can shape to approximate f.
To solve our function approximation problem we turn it into an optimization problem and use stochastic gradient descent.
We have a set of parameters that change the shape of our function f̂, and a loss function that measures how far its outputs are from the observed data.
The local search process is iterative: in each iteration (usually over a batch of observations) we change our parameters so that the loss goes down a little bit. Thanks to backpropagation we can compute the gradient of the loss with respect to every parameter, which tells us which direction to nudge each one.
This is, in the general sense of the word, a “search” - a guided search process that leverages gradients in order to find a good or optimal set of parameters.
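As a minimal sketch of that search loop - a toy linear model with two parameters, invented data, and plain stochastic gradient descent, nothing to do with any real training setup:

```python
import random

random.seed(0)

# Invented data from y = 3x + 2 plus a little noise; the "true" function is
# only known here so we can check what the search finds.
data = [(x, 3 * x + 2 + random.gauss(0, 0.1)) for x in [i / 100 for i in range(100)]]

# Two parameters that change the shape of our function f̂(x) = w * x + b.
w, b = 0.0, 0.0
learning_rate = 0.1

for epoch in range(100):
    random.shuffle(data)
    for x, y in data:                 # one observation per step (batch size 1)
        error = (w * x + b) - y       # gradient of the squared loss w.r.t. the prediction (up to a factor of 2)
        # Nudge each parameter in the direction that makes the loss a little smaller.
        w -= learning_rate * error * x
        b -= learning_rate * error

print(f"learned w={w:.2f}, b={b:.2f} (the data was generated with w=3, b=2)")
```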
The form of the function f̂ - the architecture of the network - determines the space of functions our parameters can express.
It is not very hard to create an architecture that can in theory represent any function, or at least any Turing-computable function (within given size and memory limits). RNNs are considered to be Turing-complete, even if not particularly efficient at it.
The challenge in machine learning however is not coming up with an architecture that is expressive enough.
Choosing the right architecture for the problem at hand still matters a great deal, though, because the form of f̂ shapes how easily the optimization process can find a good set of parameters.
Transformers are particularly great at learning functions over sequences of tokens such as language, and even perform very well on vision tasks too. Some might even argue that this is the only broad form of architecture we need.
Over the years there have been many tricks that are consistently applied to deep neural networks that don’t have a significant impact (if any) on the expressiveness of the network - residual connections, normalization layers, careful initialization schemes and the like.
All of these tricks serve to improve the process of moving toward an optimal set of parameters.
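As a small illustration (my own PyTorch-style sketch, not taken from anywhere in particular), a residual block combines a couple of these tricks without changing what the network can represent:

```python
import torch
from torch import nn

class ResidualBlock(nn.Module):
    """Computes x + MLP(LayerNorm(x)).

    The skip connection and the normalization don't let the network represent
    anything it couldn't represent without them - they just make the loss
    surface friendlier for gradient descent to search.
    """

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(self.norm(x))  # the residual connection

block = ResidualBlock(dim=16, hidden=64)
print(block(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```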
Back to the tweet.
Deep learning has a tendency to learn statistical associations in the data. It will cheat whenever it can - if there is a simple statistical trick, it will use it.
“On the Paradox of Learning to Reason from Data” explores logical reasoning problems as a possible example of where transformers fail to learn the underlying function we are trying to approximate. The paper presents a set of logical reasoning problems that are simple enough for a human to solve, but that a transformer model (BERT) fails to learn.
What’s actually at fault here? The paper above shows that there exists some set of parameters for the BERT architecture they use that can solve the logic problems presented to it perfectly, so the model has enough expressiveness to represent the function we are trying to learn - we know there is a solution in the parameter space.
The problem is that we fail to find that solution. This inability to find the solution we are looking for must be due to some combination of the stochastic gradient descent approach and the form of our function f̂.
In all of the above we are looking for a set of parameters that shapes our function in such a way that it approximates the observed data well. Clearly if we know that such a set of parameters exists somewhere in the parameter space, then finding it is “just” a search problem.
We can quite easily establish a very naive baseline for learning. Assuming our space of possible parameters is enumerable and we know that our function can be represented somewhere within it, we can simply try every possible set of parameters and keep the one that fits our observations best.
This is obviously not practical in non-trivial cases, as the search space grows exponentially with the number of parameters, and usually there are quite a lot of parameters. Even if each parameter were a single bit - which isn’t ridiculous in light of work such as 1-bit LLMs - a modest 7B-parameter LLM would still have a search space of 2^7,000,000,000 points, an unfathomably large number.
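For a sense of what that naive baseline looks like while it is still feasible, here is a toy sketch (the model, the data and the 4-bit parameter space are all invented for illustration) that enumerates every possible set of parameters:

```python
from itertools import product

# A toy model f̂(x) = sign(w · x) whose four weights are each constrained to -1 or +1,
# so the whole parameter space has only 2^4 = 16 points and can be enumerated.
# The data below was labelled using the weights (1, -1, 1, 1).
data = [
    ((2, 0, 1, 1), 1), ((0, 3, 1, 0), 0), ((1, 0, 0, 2), 1),
    ((0, 2, 0, 1), 0), ((1, 1, 2, 0), 1), ((2, 3, 0, 0), 0),
]

def accuracy(weights):
    correct = 0
    for x, y in data:
        score = sum(w * xi for w, xi in zip(weights, x))
        correct += int((score > 0) == bool(y))
    return correct / len(data)

# Blind, exhaustive search: evaluate every point in the parameter space.
best = max(product([-1, 1], repeat=4), key=accuracy)
print(best, accuracy(best))   # (1, -1, 1, 1) fits the observations perfectly
```

Sixteen candidate parameter settings are trivial to check; 2^7,000,000,000 are not.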
In general, the problem of black box optimization (or derivative-free optimization) is one in which we can’t do gradient descent. We have parameters, but no way of knowing how we should change them. What are your options in this kind of regime when it comes to learning - can we do better than blind search?
▲ ┌┐
│ ┌┐ ││
│ ││ ││
│ ││ ││ ┌┐
│ ┌┐ ││ ││ ┌┐ ││
│ ││ ││ ┌┐ ││ ││ ││
│ ││ ││ ││ ┌┐ ││ ││ ││
│ ││ ┌┐ ││ ││ ││ ││ ││ ┌┐ ││
└─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─►
▲ ┌┐
│ ┌┐ ││ ┌┐
│ ││ ││ ││
│ ┌┐ ││ ││ ││ ┌┐
│ ┌┐ ││ ││ ││ ││ ││
│ ││ ││ ││ ││ ││ ││ ┌┐
│ ││ ││ ││ ││ ││ ││ ││ ┌┐
│ ││ ││ ││ ││ ││ ││ ││ ││ ┌┐
└─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─►
The first example above has several local optima and no smoothness - a close to random relationship between x and y - which is not very good for gradient descent.
Evolution is the obvious example of black box optimization. There’s no analytical way to determine a good change of parameters (genes), so you start with something, change things randomly, and propagate the winners. Evolution, and genetic algorithms, make the assumption that points that are close together in parameter space will have more similar fitness than points that are far apart.
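A minimal sketch of that idea (a made-up hidden bitstring stands in for “the environment”; just mutation and selection, nothing fancier):

```python
import random

random.seed(1)

# Black box fitness: we can evaluate it, but we get no gradients from it.
TARGET = [random.randint(0, 1) for _ in range(40)]

def fitness(genome):
    return sum(g == t for g, t in zip(genome, TARGET))

def mutate(genome, rate=0.05):
    # Small random changes - the bet is that nearby genomes have similar fitness.
    return [1 - g if random.random() < rate else g for g in genome]

# Start with something, change things randomly, propagate winners.
population = [[random.randint(0, 1) for _ in range(40)] for _ in range(20)]
for generation in range(200):
    population.sort(key=fitness, reverse=True)
    winners = population[:5]                                        # selection
    population = winners + [mutate(random.choice(winners)) for _ in range(15)]

best = max(population, key=fitness)
print(f"best fitness: {fitness(best)} / {len(TARGET)}")
```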
Without this assumption you really have no option but blind search. A lot of cryptography depends on the idea that your only approach to solving the problem is a brute force search over the parameter space, and this is effectively what you are doing when you mine Bitcoin.
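In that spirit - and only as a toy sketch, not actual Bitcoin mining - brute force proof-of-work looks like this: a failed guess tells you nothing about which guess to try next, so you just keep guessing:

```python
import hashlib
from itertools import count

# Toy proof-of-work: find a nonce whose hash, combined with some fixed data,
# starts with a given number of zero hex digits. There is no gradient to
# follow here - only exhaustive search.
block_data = b"some block contents"
difficulty = 4   # leading zero hex digits (real mining difficulty is vastly higher)

for nonce in count():
    digest = hashlib.sha256(block_data + str(nonce).encode()).hexdigest()
    if digest.startswith("0" * difficulty):
        print(f"nonce={nonce} digest={digest}")
        break
```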
In the above cases, we are trying to fit a set of observations that we have presumably already collected. But an important objective of machine learning is to come up with solutions that generalize well. That is, our solution / model / function approximation should also do a good job when it is presented with data it hasn’t seen before.
Perfectly predicting the available data is often not that hard. The trivial case here is nearest neighbor search, which effectively just stores the complete mapping of observations and can therefore reproduce them with perfect accuracy given enough memory. Neural networks with enough parameters will often learn to predict the training data with 100% accuracy. In the reasoning-with-BERT paper, the model is able to learn the training data well, yet fails when presented with new data it hasn’t seen before.
There’s not always a way to know whether your solution generalizes well to new data points, particularly if those data points are not some interpolation of the data you have seen (out of distribution). When we can, we hold back some data that we don’t use to learn our approximation of the function, and then evaluate the trained model on this held-out data to see how it performs.
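A toy sketch of both points at once (invented data; 1-nearest-neighbor playing the role of the memorizer): the model gets the training data perfectly right because it has simply stored it, and only the held-out data says anything about generalization:

```python
import random

random.seed(2)

# Invented data: the label is 1 when the two features sum to more than 1,
# with a little label noise thrown in.
def make_point():
    x = (random.random(), random.random())
    y = int(x[0] + x[1] > 1.0)
    if random.random() < 0.1:
        y = 1 - y
    return x, y

data = [make_point() for _ in range(200)]
train, test = data[:150], data[150:]   # hold back data we never fit on

def predict(x, memory):
    # 1-nearest-neighbor: just look up the closest stored observation.
    nearest = min(memory, key=lambda p: (p[0][0] - x[0]) ** 2 + (p[0][1] - x[1]) ** 2)
    return nearest[1]

def accuracy(points, memory):
    return sum(predict(x, memory) == y for x, y in points) / len(points)

print("train accuracy:", accuracy(train, train))   # 1.0 - it has memorized everything
print("test accuracy: ", accuracy(test, train))    # lower - this is the generalization question
```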
We have some intuitions and heuristics here too, like preferring simpler models if they explain the data just as well. This isn’t a concept constrained to machine learning - in science generally, when trying to understand the world around us, a simpler model is preferred, provided it doesn’t give up any accuracy.
But overall, we cannot directly optimize for generalization, only proxies for it that work more often than not, or bake structure into our f̂ that encourages generalization more often than not - and that seems to be the best we can hope for.
Can we definitively create an architecture or training process that reliably recovers the underlying function rather than a statistical shortcut?