This past December, I wanted to deep-dive into machine learning to develop more solid intuition about different model architectures, starting from the simplest multi-layer perceptrons all the way up to transformers. Along the way, I came across the idea of representation learning, which offered a helpful way of visualising what models actually do. This post is an attempt to create something that's easy to point at: a distillation of those ideas of manifold manipulation and the geometric interpretation of feature learning and reshaping.
First and foremost, this post does not seek to replicate the myriad resources on neural networks that already exist out in the world; instead, the hope is that for readers with a basic foundation (assuming enough ML knowledge that we needn't redefine e.g. neurons, weights, biases, etc.), thinking through the lens of manifolds can offer a more intuitive, visual way of understanding the empirics of neural networks.
A foray into the fundamentals
Almost all neural networks today are a composition of layers arranged in a feed-forward computation graph that takes in inputs and performs a sequential set of combinations and activations on them. In other words, you can think of each layer as a specific transformation (in the geometric sense) that modifies the geometric relationships within the high-dimensional space defined by the set of input features. The whole network, then, is essentially a composition of these transformations that together sculpt the data into a particular geometric arrangement.
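To make this concrete, here is a minimal sketch (layer widths, initialisation, and the input points are all illustrative, not from any particular model) of a feed-forward network written explicitly as a composition of geometric transformations:

```python
import numpy as np

rng = np.random.default_rng(0)

def layer(x, W, b):
    # An affine map (rotate/scale/shear/translate the space) followed by
    # a pointwise nonlinearity that bends and folds it.
    return np.tanh(x @ W + b)

# Three layers composed in sequence: 2 -> 4 -> 4 -> 2
shapes = [(2, 4), (4, 4), (4, 2)]
params = [(rng.normal(size=s), rng.normal(size=s[1])) for s in shapes]

x = rng.normal(size=(5, 2))   # five points in a 2-D input space
h = x
for W, b in params:
    h = layer(h, W, b)        # each layer reshapes the current space

print(h.shape)  # (5, 2): the same five points, now living in a transformed space
```

The network is literally just the composition `layer(layer(layer(x, ...), ...), ...)`; everything interesting lies in what each transformation does to the geometry of the points.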
We can examine a simple geometric example to see this principle in action. Specifically, consider classification, which can be reduced to learning a function that approximates the decision boundary between different classes.
Intuitively, the simplest decision boundary to draw is a line. However, not all data is linearly separable (as in the above figure), so to use such a simple boundary we must first employ more complicated transformations that make the data roughly linearly separable.
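A classic toy illustration of this (my own hand-crafted example, not a learned network): two concentric rings can never be split by a line in 2-D, but a single "layer" that appends the squared radius as a new coordinate makes a flat plane separate them perfectly.

```python
import numpy as np

rng = np.random.default_rng(1)

def ring(radius, n):
    # n points scattered around a circle of the given radius
    theta = rng.uniform(0, 2 * np.pi, n)
    return np.column_stack([radius * np.cos(theta), radius * np.sin(theta)])

inner, outer = ring(1.0, 100), ring(3.0, 100)

def lift(points):
    # (x, y) -> (x, y, x^2 + y^2): the kind of reshaping a layer can learn
    r2 = (points ** 2).sum(axis=1, keepdims=True)
    return np.hstack([points, r2])

# In the lifted 3-D space, the plane z = 5 separates the classes perfectly:
# the inner ring sits near z = 1 and the outer ring near z = 9.
assert lift(inner)[:, 2].max() < 5 < lift(outer)[:, 2].min()
```

A trained network discovers transformations like `lift` on its own; the hand-written version just makes the geometry easy to see.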
A transformative paradigm shift
One intuitive question to ask when thinking about the above idea of a neural network as a composition of transformations is as follows:
After the composition of transformations (i.e. at the output layer), what does the geometry of the input space look like?
This is indeed the key idea of representation learning and manifold manipulation: reshaping the axes that the data lies on in order to make the shape of the dataset more amenable to simple operations like linear separation. Specifically, consider what would happen if we passed every single point in the initial input space through the first layer's transformation: the input space itself would be transformed into a new shape, with different geometric relationships than the initial high-dimensional space (it might even have a different dimensionality, depending on how many units the layer has). As we do this across each layer in turn, we build up a hierarchy of representations of the original dataset, with each step ideally bringing us closer to a simple, linearly separable arrangement of the data.
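We can sketch this "transform the whole space" view directly: push a regular grid covering the input space through a stack of (randomly initialised, hence untrained) layers and watch the representation's dimensionality change at each step. The widths here are purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(2)

# A regular grid of 400 points covering the 2-D input space
xs = np.linspace(-1, 1, 20)
grid = np.array([(x, y) for x in xs for y in xs])

widths = [2, 8, 3]  # each layer's unit count sets the next space's dimension
h = grid
for w_out in widths:
    W = rng.normal(size=(h.shape[1], w_out))
    b = rng.normal(size=w_out)
    h = np.tanh(h @ W + b)
    # Every point in the space has been mapped into a new w_out-dimensional space
    print(h.shape)
```

Each `print` shows the entire grid re-embedded in a new space: `(400, 2)`, then `(400, 8)`, then `(400, 3)`. Training is the process of choosing `W` and `b` at each step so the final space is the most useful one.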
The more nuanced manifold learning hypothesis
Though this, in and of itself, is already the intuition behind manifold learning, we can add one more nuance about the shape of the data that lets this geometric intuition carry further implications for ML.
Specifically, the full representation learning hypothesis postulates that the input training data usually lies on (or around) a low-dimensional manifold (in the formal mathematical sense) within the full input space, termed the "ambient space". Combined with our intuition above, this picture clarifies the goal of composing the different transformations at each layer: we simply want to compose transformations that reshape the data manifold in such a way that any relevant, predictive trends can be represented with simple functions at the output layer. In that sense, most ML models are literally "representation learners"; that is, they must learn transformations that most clearly represent the inherent trends in the data in a way that's most simply comprehensible for the output layers.
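A hypothetical numerical illustration of the manifold hypothesis (synthetic data, chosen purely for demonstration): take points on a 1-D curve, embed them linearly into a 10-D ambient space with a little noise, and observe via SVD that nearly all the variance is concentrated in the few directions the curve actually spans.

```python
import numpy as np

rng = np.random.default_rng(3)

# A 1-D manifold (a helix, parameterised by t) sitting in 3-D
t = rng.uniform(0, 4 * np.pi, 1000)
curve = np.column_stack([np.cos(t), np.sin(t), 0.1 * t])

# Embed it into a 10-D ambient space with a random linear map + small noise
embed = rng.normal(size=(3, 10))
data = curve @ embed + 0.01 * rng.normal(size=(1000, 10))

# SVD of the centred data: how much variance do the top 3 directions explain?
centered = data - data.mean(axis=0)
sv = np.linalg.svd(centered, compute_uv=False)
explained = sv ** 2 / (sv ** 2).sum()
print(explained[:3].sum())  # close to 1.0: effectively low-dimensional data
```

Even though every point has 10 coordinates, the data only "uses" a thin slice of the ambient space; that slice is what the network's transformations get to work with.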
Implications of manifold manipulation/case studies
In thinking geometrically about layer operations within neural networks, we thus come across a number of different implications:
- Geometric interpretations of tasks. A composition of transformations can allow for simple linear (or single-function) representations of boundaries between different classes, as discussed above. However, transformations of different kinds of data (time series, sequences, etc.) can also produce powerful models of dynamical systems or trajectories; the natural example is autoregressive text prediction (today's LLMs), where a vocabulary of 50-60k tokens can generate an effectively unbounded set of dynamical trajectories, allowing for the vast capabilities that we see today.
- Dynamical and state-space relationships. In an upcoming post, I'll talk more about the ideas of action-consequence and recurrence in learning, as well as how to reason and plan with them; understanding geometric/dynamical-systems representations of these ideas can lead to more powerful state-space representations and RNN/S4-type models that use more general primitives and learn more effectively from consequences in a continual fashion.
- First-principles ML. Through an understanding of the actual reshaping and manifold that the model creates, we may be able to better understand (and even inject) inductive biases into the dataset and the model's transformational capabilities instead of relying on pure empirics and heuristics. This directly leads to the idea of producing more on-policy models, as well as models that might be more immediately mechanistically interpretable.