Gradient based learning and black box optimization

2024-04-18 / 12 min read

A few thoughts attempting to frame, in my own head, some related terminology and concepts around the limits of learning, optimization and generalization, mostly prompted by this tweet.

Machine learning as function approximation #

Much of machine learning at its core can be boiled down to the problem of attempting to approximate a function.

Practically all supervised learning fits this framing: in a supervised problem you have examples of input data $x$ paired with outputs $y$, and you attempt to learn a function $f$ such that:

$$f(x) = y$$

Self-supervised learning fits here too, and LLMs are a good example: we are trying to find a function that takes a sequence of tokens as input and predicts the next token. While we sometimes talk about just feeding the model “data”, we are in fact creating a structured supervised learning problem out of that data.
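
To make the "structured supervised problem" concrete, here is a minimal sketch of how raw text could be turned into (input, output) pairs for next-token prediction. The function name `next_token_pairs`, the whitespace-split tokens and the `context_size` parameter are illustrative assumptions, not how any particular LLM pipeline does it.

```python
def next_token_pairs(tokens, context_size):
    """Turn a flat token sequence into (context, next_token) training pairs."""
    pairs = []
    for i in range(1, len(tokens)):
        # The input is the (truncated) history, the output is the token that follows it.
        context = tuple(tokens[max(0, i - context_size):i])
        pairs.append((context, tokens[i]))
    return pairs


tokens = ["the", "cat", "sat", "on", "the", "mat"]
pairs = next_token_pairs(tokens, context_size=3)
for context, target in pairs:
    print(context, "->", target)
```

Every position in the text yields one labeled example, which is why no separate labeling step is needed.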

We do not know $f$, so we attempt to learn it from the (usually many) observations of $(x, y)$ pairs that we have.

In deep learning, we do this by defining a separate function $f'$ that takes a set of parameters $\theta$ as well as the input $x$, and using gradient descent to learn the parameters.

Gradient descent is guided search over $\theta$ #

To solve our function approximation problem we turn it into an optimization problem and use stochastic gradient descent.

We have a set of parameters $\theta$ that change the shape of our function $f'$, and we know that somewhere in our parameter space there is a good, or even best, solution that we are looking for. We define “good”, or more typically its inverse in the form of a loss (or “bad”) function, based on how well $f'$ approximates the data points in our training set. Sometimes we also add other loss terms that measure internal properties of the model and encourage desirable properties such as sparsity or small weights (regularization).
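
As a rough illustration of a parametric function plus a loss, here is a sketch in plain Python. The linear model `f_prime`, the made-up data and the `l2_weight` regularization term are all assumptions chosen for brevity; a real network would have far more parameters.

```python
def f_prime(theta, x):
    """Hypothetical parametric model: a line with slope theta[0] and intercept theta[1]."""
    return theta[0] * x + theta[1]


def loss(theta, data, l2_weight=0.0):
    """Mean squared error over the data, plus an optional L2 penalty on theta."""
    mse = sum((f_prime(theta, x) - y) ** 2 for x, y in data) / len(data)
    penalty = l2_weight * sum(t * t for t in theta)
    return mse + penalty


data = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]  # made-up observations of y = 2x + 1
good_loss = loss([2.0, 1.0], data)  # parameters that match the data exactly
bad_loss = loss([0.0, 0.0], data)   # parameters that ignore the data
print(good_loss, bad_loss)
```

The data term measures fit, while the penalty term is an example of a loss on the model itself rather than on its predictions.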

The local search process is iterative: in each iteration (usually a batch of observations) we change our parameters so that our loss goes down a little bit. Because $f'$ and the loss are differentiable, we can compute the gradient of the loss with respect to the parameters, which gives us a set of small changes that would make the loss lower. We repeat this process until our loss is low enough or stops going any lower.

This is, in the general sense of the word, a “search”: a guided search process that leverages gradients in order to find a good or optimal set of parameters.
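
The iterative process described above can be sketched in a few lines of plain Python. This assumes a two-parameter linear model with a hand-derived gradient and noiseless made-up data; the names `sgd`, `gradient` and the hyperparameter values are illustrative, not a recommended configuration.

```python
import random


def predict(theta, x):
    return theta[0] * x + theta[1]


def gradient(theta, batch):
    """Gradient of mean squared error with respect to (slope, intercept)."""
    g_slope = g_intercept = 0.0
    for x, y in batch:
        error = predict(theta, x) - y
        g_slope += 2 * error * x
        g_intercept += 2 * error
    return [g_slope / len(batch), g_intercept / len(batch)]


def sgd(data, steps=2000, batch_size=4, learning_rate=0.05, seed=0):
    rng = random.Random(seed)
    theta = [0.0, 0.0]  # arbitrary starting point in parameter space
    for _ in range(steps):
        batch = rng.sample(data, batch_size)  # a random mini-batch
        grad = gradient(theta, batch)
        # Step against the gradient so the loss on this batch goes down a little.
        theta = [t - learning_rate * g for t, g in zip(theta, grad)]
    return theta


# Made-up, noiseless observations of y = 2x + 1 on [-1, 1].
data = [(x / 10, 2.0 * (x / 10) + 1.0) for x in range(-10, 11)]
theta = sgd(data)
print(theta)
```

Each step only uses local information (the gradient at the current parameters), which is what makes this a local, guided search.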

Expressiveness vs learnability #

The form of the function $f'$ - which in deep learning amounts to the architecture of the neural network - defines the possible types of functions $f'$ can be.

It is not very hard to create an architecture that can in theory represent any function, or at least any Turing-computable function (within given size / memory limits). RNNs are considered to be Turing-complete under idealized assumptions such as unbounded precision, even if not particularly efficient at it.

The challenge in machine learning, however, is not coming up with an $f'$ that is expressive enough, but coming up with an $f'$ that works well with the applied method of learning (gradient descent), such that it can efficiently and successfully learn (or find) a set of parameters that approximates the function well.

Better $f'$s #

Choosing the right $f'$ is hard, and arguably this is the most common form of research in machine learning / AI. Some notable examples are convolutional networks and transformers. The former is used extensively for object recognition, and is designed to have a desirable translation invariance property such that:

$$f'(\theta, \mathrm{translate}(x)) = f'(\theta, x)$$
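
One detail worth making concrete: a convolution on its own is translation equivariant (shifting the input shifts the output), and invariance appears once a pooling step throws away position. The sketch below assumes a 1D signal, wrap-around shifts and a made-up kernel; `circular_conv`, `features` and `shift` are hypothetical helper names.

```python
def circular_conv(signal, kernel):
    """1D cross-correlation with wrap-around, so shifts never fall off the edge."""
    n = len(signal)
    return [
        sum(kernel[k] * signal[(i + k) % n] for k in range(len(kernel)))
        for i in range(n)
    ]


def features(signal, kernel):
    """Convolution followed by global max pooling: one number per signal."""
    return max(circular_conv(signal, kernel))


def shift(signal, amount):
    """Cyclically shift the signal to the right by `amount` positions."""
    amount %= len(signal)
    return signal[-amount:] + signal[:-amount] if amount else list(signal)


signal = [0, 0, 1, 3, 1, 0, 0, 0]  # a small "bump" at one position
kernel = [1, 2, 1]                 # a made-up bump detector

for s in range(len(signal)):
    # The pooled feature is the same wherever the bump sits.
    print(s, features(shift(signal, s), kernel))
```

The same kernel weights are reused at every position, so this structure is baked into the architecture rather than learned from data.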

Transformers are particularly great at learning functions over sequences of tokens such as language, and perform very well on vision tasks too. Some might even argue that this is the only broad form of $f'$ that we will ever need.

Over the years there have been many tricks that are consistently applied to deep neural networks that don’t have a significant impact (if any) on the expressiveness of $f'$ but make it more efficient at learning through gradient descent. A few examples:

  • Rectified linear activation functions - a simple-to-compute non-linear activation function.
  • Dropout - stochastic zeroing out of activations during training, which acts as a form of regularization and promotes generalization.
  • Residual connections - connections that skip one or more layers, helping the gradients to flow through deep architectures.
  • Batch and layer normalization - normalizing outputs of layers one way or the other, stabilizing gradients.

All of these tricks serve to improve the process of moving toward an optimal set of parameters.
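
For readers who have not seen these tricks written out, here is a deliberately simplified sketch on plain Python lists. It assumes inverted dropout and a layer normalization without the usual learnable scale and shift; none of this is tied to a specific framework's implementation.

```python
import random


def relu(values):
    return [max(0.0, v) for v in values]


def dropout(values, drop_prob, rng, training=True):
    """Inverted dropout: zero each value with probability drop_prob, rescale the rest."""
    if not training or drop_prob == 0.0:
        return list(values)
    keep = 1.0 - drop_prob
    return [v / keep if rng.random() < keep else 0.0 for v in values]


def layer_norm(values, eps=1e-5):
    """Normalize to zero mean and (roughly) unit variance across the vector."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / (var + eps) ** 0.5 for v in values]


def residual_block(values, layer):
    """Add the layer's output to its input, so an identity path always exists."""
    return [v + out for v, out in zip(values, layer(values))]


rng = random.Random(0)
x = [1.0, -2.0, 3.0, -4.0]
h = residual_block(x, lambda v: dropout(relu(layer_norm(v)), 0.5, rng))
print(h)
```

None of these pieces lets the network represent functions it could not represent before; they change how easily gradient descent can move through parameter space.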

When gradient descent doesn’t work #

Back to the tweet.

Deep learning has a tendency to learn statistical associations in the data. It will cheat whenever it can - if there is a simple statistical trick, it will use it.

The paper “On the Paradox of Learning to Reason from Data” explores logical reasoning problems as a possible example of where transformers fail to learn the underlying function we are trying to approximate. The paper presents a set of logical reasoning problems that are simple enough for a human to solve, but that a transformer model (BERT) fails to learn.

What’s actually at fault here? The paper above shows that there exists some set of parameters for the BERT architecture they use that can solve the logic problems presented to it perfectly, so the model has enough expressiveness to represent the function we are trying to learn - we know there is a solution in the parameter space.

The problem is that we fail to find that solution. This inability to find the solution we are looking for must be due to some combination of the approach of stochastic gradient descent and the form of our function $f'$, and we mostly attempt to address the problem by coming up with new $f'$s rather than changing SGD itself. There is of course no guarantee that SGD will ever find a global optimum.
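
A toy illustration of "a solution exists but the search does not find it": plain gradient descent on a made-up one-dimensional loss with two valleys. The polynomial, learning rate and starting points are assumptions picked only to make the effect visible, and this says nothing about what actually happens inside BERT.

```python
def loss(w):
    """Made-up non-convex loss with a shallow valley (w > 0) and a deep one (w < 0)."""
    return w**4 - 3 * w**2 + w


def grad(w):
    return 4 * w**3 - 6 * w + 1


def gradient_descent(w, learning_rate=0.01, steps=2000):
    for _ in range(steps):
        w -= learning_rate * grad(w)
    return w


w_from_right = gradient_descent(2.0)   # settles in the shallow valley
w_from_left = gradient_descent(-2.0)   # settles in the deep valley
print(w_from_right, loss(w_from_right))
print(w_from_left, loss(w_from_left))
```

Both runs stop where the gradient vanishes, but only one of them reaches the better solution; which one you get depends entirely on where the search started.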

Blind search, black box optimization #

In all of the above we are looking for a set of parameters that shapes our function $f'$ in such a way that it approximates the observed data well. Clearly, if we know $f'$ has enough expressiveness to represent our target function, then there exists at least one set of parameters that we would consider to be a solution to the problem.

We can quite easily establish a very naive baseline for learning. Assuming our space of possible parameters is enumerable and we know that our $f'$ has the required expressiveness, the most naive learning algorithm that is guaranteed to work given enough resources would be to simply go through all possible parameters, evaluating the function against our observed data until we find something with low / the lowest loss.
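
The naive baseline can be written down directly. This sketch assumes four binary parameters and a hypothetical threshold model, with the training data generated from a hidden parameter setting so that a zero-loss solution is known to exist.

```python
from itertools import product


def model(theta, x):
    """Hypothetical tiny model: fires if at least two selected inputs are on."""
    total = sum(t * xi for t, xi in zip(theta, x))
    return 1 if total >= 2 else 0


def loss(theta, data):
    """Number of observations the model gets wrong."""
    return sum(model(theta, x) != y for x, y in data)


def exhaustive_search(n_params, data):
    best_theta, best_loss = None, float("inf")
    # Enumerate every point in the (finite) parameter space.
    for theta in product([0, 1], repeat=n_params):
        current = loss(theta, data)
        if current < best_loss:
            best_theta, best_loss = theta, current
    return best_theta, best_loss


hidden = (1, 0, 1, 1)  # the "true" parameters, unknown to the search
data = [(x, model(hidden, x)) for x in product([0, 1], repeat=4)]
best_theta, best_loss = exhaustive_search(4, data)
print(best_theta, best_loss)
```

With 4 binary parameters there are only 16 candidates to check; every extra parameter doubles that count.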

This is obviously not practical in non-trivial cases, as the search space grows exponentially with the number of parameters we have, and usually there are quite a lot of parameters. Even if each parameter were a single bit, which isn’t ridiculous in light of work such as 1-bit LLMs, a modest 7B-parameter LLM would have $2^{7 \times 10^9}$ possible parameter settings, which is an unfathomably large number.
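
To get a feel for the scale, the size of that number can be estimated without ever constructing it. The comparison against roughly 10^80 atoms in the observable universe uses a commonly quoted order-of-magnitude estimate and is only there for intuition.

```python
import math

n_params = 7_000_000_000  # assumed: 7B parameters, one bit each

# Number of decimal digits in 2 ** n_params, computed via logarithms.
decimal_digits = math.floor(n_params * math.log10(2)) + 1
print(f"2^{n_params:,} has about {decimal_digits:,} decimal digits")

# For comparison: how many 1-bit parameters already give more settings
# than a rough estimate of the number of atoms in the observable universe?
bits_for_atoms = math.ceil(80 / math.log10(2))
print(f"2^{bits_for_atoms} already exceeds 10^80")
```

So the count of candidate settings is not a number with billions as its value, it is a number with billions of digits, and a model with a few hundred binary parameters is already beyond exhaustive enumeration.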

In general, the problem of black box optimization (or derivative-free optimization) is one in which we can’t do gradient descent. We have parameters, but no gradient to tell us how to change them. What are our options for learning in this kind of regime? Can we do better than blind search?

 ▲                ┌┐            
 │       ┌┐       ││            
 │       ││       ││            
 │       ││       ││       ┌┐   
 │ ┌┐    ││       ││ ┌┐    ││   
 │ ││    ││ ┌┐    ││ ││    ││   
 │ ││    ││ ││ ┌┐ ││ ││    ││   
 │ ││ ┌┐ ││ ││ ││ ││ ││ ┌┐ ││   
 ▲          ┌┐                  
 │       ┌┐ ││ ┌┐               
 │       ││ ││ ││               
 │    ┌┐ ││ ││ ││ ┌┐            
 │ ┌┐ ││ ││ ││ ││ ││            
 │ ││ ││ ││ ││ ││ ││ ┌┐         
 │ ││ ││ ││ ││ ││ ││ ││ ┌┐      
 │ ││ ││ ││ ││ ││ ││ ││ ││ ┌┐   

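The first example above has several local optima and no smoothness, that is, a close to random relationship between x and y, which is not very good for gradient descent. The second is smooth with a single peak, so nearby values of x give similar values of y and local improvements lead toward the optimum.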

Evolution is the obvious example of black box optimization. There’s no analytical way to determine a good change of parameters (genes), so you start with something, change things randomly and propagate the winners. Evolution, and genetic algorithms, rely on the assumption that parameters close to each other in parameter space will have more similar fitness than those that are far apart.
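
A minimal sketch of the mutate-and-select idea, assuming a smooth made-up fitness function that the algorithm can only evaluate, never differentiate. The population size, mutation scale and the name `evolve` are illustrative choices, and real genetic algorithms typically add crossover and larger populations.

```python
import random


def fitness(theta):
    """Hypothetical smooth fitness with a single peak at (3, -1); treated as a black box."""
    return -((theta[0] - 3.0) ** 2 + (theta[1] + 1.0) ** 2)


def evolve(generations=300, offspring=20, mutation_scale=0.3, seed=0):
    rng = random.Random(seed)
    parent = [0.0, 0.0]
    for _ in range(generations):
        # Mutate: each child is a small random perturbation of the parent.
        children = [
            [gene + rng.gauss(0.0, mutation_scale) for gene in parent]
            for _ in range(offspring)
        ]
        # Select: keep the fittest, including the parent so fitness never gets worse.
        parent = max(children + [parent], key=fitness)
    return parent


best = evolve()
print(best, fitness(best))
```

This only works because small mutations tend to produce small changes in fitness; on a landscape where neighbors are unrelated, the selection step has nothing to exploit.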

Without this assumption you really don’t have any option but blind search. A lot of cryptography depends on the idea that the only approach to solving the problem is a brute force search over the parameter space, and this is effectively what you are doing when you mine for bitcoins.
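
A toy version of that kind of search, loosely in the spirit of proof-of-work. It assumes a simplified puzzle (find a nonce whose SHA-256 hex digest starts with a given prefix) and is not Bitcoin's actual block format or difficulty rule.

```python
import hashlib
from itertools import count


def find_nonce(message, prefix="000"):
    """Try nonces 0, 1, 2, ... until the SHA-256 hex digest starts with `prefix`."""
    for nonce in count():
        digest = hashlib.sha256(f"{message}:{nonce}".encode()).hexdigest()
        if digest.startswith(prefix):
            return nonce, digest


nonce, digest = find_nonce("hello")
print(nonce, digest)
```

A nonce that almost works tells you nothing about where a working one is, so there is no signal for a guided search to follow and each extra hex digit in the prefix multiplies the expected work by 16.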

What makes a good $f'$ #

In the above cases, we are trying to fit a set of observations that we have presumably already collected. But an important objective of machine learning is to come up with solutions that generalize well. That is, our solution / model / function approximation also does a good job when it’s presented with data that it hasn’t seen before.

Perfectly predicting the available data is often not that hard. The trivial case here is nearest neighbor search, which effectively just stores the complete mapping of observations and can therefore reproduce them with perfect accuracy given enough memory. Neural networks with enough parameters will often learn to predict the training data with 100% accuracy. In the BERT reasoning paper mentioned above, the model is able to learn the training data well, yet it fails when presented with new data it hasn’t seen before.
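
The nearest neighbor case takes only a few lines to write. This sketch assumes one-dimensional inputs with distinct values and made-up labels that follow no simple rule.

```python
def nearest_neighbor(train, x):
    """Return the label of the stored example closest to x."""
    _, label = min(train, key=lambda pair: abs(pair[0] - x))
    return label


# Made-up data: the labels are arbitrary, there is nothing to "understand".
train = [(0.1, "a"), (0.4, "b"), (0.5, "a"), (0.9, "b"), (1.3, "a")]

train_accuracy = sum(nearest_neighbor(train, x) == y for x, y in train) / len(train)
print(train_accuracy)
```

The training accuracy is perfect by construction, because every query finds itself in the stored data, so it says nothing about how the predictor behaves on new inputs.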

There’s not always a way to know if your solution generalizes well to new data points, particularly if those data points are not some interpolation of the data that you have seen (that is, they are out of distribution). When we can, we hold back some data that we don’t use to learn our approximation of the function, and then evaluate the trained model on it, without training on it, to see how it performs.
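
A sketch of the hold-out idea, assuming a made-up task (predict the parity of an integer) and a deliberately bad "model" that memorizes its training set. The helper names `train_test_split`, `accuracy` and `memorizer` are hypothetical.

```python
import random


def train_test_split(data, test_fraction=0.2, seed=0):
    """Shuffle a copy of the data and hold back a fraction for evaluation only."""
    rng = random.Random(seed)
    shuffled = list(data)
    rng.shuffle(shuffled)
    n_test = int(len(shuffled) * test_fraction)
    return shuffled[n_test:], shuffled[:n_test]


def accuracy(predict, examples):
    return sum(predict(x) == y for x, y in examples) / len(examples)


data = [(x, x % 2) for x in range(100)]  # made-up task: parity of an integer
train, test = train_test_split(data)

lookup = dict(train)  # "training" here is pure memorization


def memorizer(x):
    """Perfect on seen inputs, falls back to a constant guess otherwise."""
    return lookup.get(x, 0)


train_accuracy = accuracy(memorizer, train)
test_accuracy = accuracy(memorizer, test)
print(train_accuracy, test_accuracy)
```

The memorizer scores perfectly on the data it has stored, while on held-out inputs it can only guess, so only the second number says anything about generalization. Note that a random split like this still tests in-distribution inputs; it does not tell you how the model behaves on, say, integers outside the range it was trained on.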

We have some intuitions and heuristics here too, like preferring simpler models if they explain the data equally well. This isn’t a concept constrained to machine learning; it is generally considered good practice in science when trying to understand the world around us: a simpler model, provided it doesn’t give up any accuracy, is generally preferred.

But overall, we cannot directly optimize for generalization. We can only optimize proxies for it that work more often than not, or bake structure into our $f'$ that encourages generalization more often than not - and that seems to be the best we can hope for.

Final thoughts #

Can we definitely create an architecture, or $f'$, that is capable of learning to solve and generalize to any problem through gradient descent? This isn’t obviously true at all. There clearly exist some forms of problems for which a gradient based approach will never work. So far we’ve found that we can usually create better $f'$s, or train with more data, to find a more general solution, but it’s useful to understand the limits of what we are doing and what the core motivation behind experimenting with new deep learning architectures ultimately is.

