Semantic representations (aka semantic embeddings) of complicated data types (e.g. images, text, video) have become central in machine learning, and also crop up in machine translation, language models, GANs, domain transfer, etc. These involve learning a representation function $f$ such that for any data point $x$ its representation $f(x)$ is “high level” (retains semantic information while discarding low level details, such as color of individual pixels in an image) and “compact” (low dimensional). The test of a good representation is that it should greatly simplify solving new classification tasks, by allowing them to be solved via linear classifiers (or other low-complexity classifiers) using small amounts of labeled data.
Researchers are most interested in unsupervised representation learning using unlabeled data. A popular approach is to use objectives similar to the word2vec algorithm for word embeddings, which work well for diverse data types such as molecules, social networks, images, and text. See the Wikipedia page on word2vec for references. Why do such objectives succeed in such diverse settings? This post explains these methods using our new theoretical framework with coauthor Misha Khodak. The framework makes minimalistic assumptions, which is a good thing, since word2vec-like algorithms apply to vastly different data types and it is unlikely that they can share a common Bayesian generative model for the data. (An example of generative models in this space is described in an earlier blog post on the RAND-WALK model.) As a bonus, this framework also yields principled ways to design new variants of the training objectives.
Semantic representation learning
Do good, broadly useful representations even exist in the first place? In domains such as computer vision, we know the answer is “yes” because deep convolutional neural networks (CNNs), when trained to high accuracy on large multiclass labeled datasets such as ImageNet, end up learning very powerful and succinct representations along the way. The penultimate layer of the net — the input into the final softmax layer — serves as a good semantic embedding of the image in new unrelated visual tasks. (Other layers from the trained net can also serve as good embeddings.) In fact the availability of such embeddings from pre-trained (on large multiclass datasets) nets has led to a revolution in computer vision, allowing a host of new classification tasks to be solved with low-complexity classifiers (e.g. linear classifiers) using very little labeled data. Thus they are the gold standard to compare to if we try to learn embeddings via unlabeled data.
word2vec-like methods: CURL
Since the success of word2vec, similar approaches were used to learn embeddings for sentences and paragraphs, images and biological sequences. All of these methods share a key idea: they leverage access to pairs of similar data points $x, x^+$, and learn an embedding function $f$ such that the inner product of $f(x)$ and $f(x^+)$ is on average higher than the inner product of $f(x)$ and $f(x^-)$, where $x^{-}$ is a random data point (and thus presumably dissimilar to $x$). In practice similar data points are usually found heuristically, often using co-occurrences, e.g. consecutive sentences in a large text corpus, nearby frames in a video clip, different patches in the same image, etc.
A good example of such methods is Quick Thoughts (QT) from Logeswaran and Lee, which is the state-of-the-art unsupervised text embedding on many tasks. To learn a representation function $f$, QT minimizes the following loss function on a large text corpus
\[L_{un}(f):=\mathop{\mathbb{E}}\limits_{x,x^+,x^-}\left[-\log\left(\frac{e^{f(x)^Tf(x^+)}}{e^{f(x)^Tf(x^+)}+e^{f(x)^Tf(x^-)}}\right)\right]\ \ \ \ \ \ (1)\]
where $(x, x^+)$ are consecutive sentences, presumed to be “semantically similar”, and $x^-$ is a random negative sample. For images, $x$ and $x^+$ could be nearby frames from a video. For text, two successive sentences serve as good candidates for a similar pair: for example, the following are two successive sentences in the Wikipedia page on word2vec: “High frequency words often provide little information.” and “Words with frequency above a certain threshold may be subsampled to increase training speed.” Clearly they are much more similar than a random pair of sentences, and the learner exploits this. From now on we use Contrastive Unsupervised Representation Learning (CURL) to refer to methods that leverage similar pairs of data points, and our goal is to analyze these methods.
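To make the objective in (1) concrete, here is a minimal sketch that evaluates it on a handful of triples. The helper names and the toy 2-dimensional embeddings are made up for illustration; a real system would compute $f(x)$ with a trained network rather than hard-coded vectors.

```python
import math

def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def contrastive_loss(fx, fx_pos, fx_neg):
    """Per-triple loss from (1): -log(e^a / (e^a + e^b)),
    with a = f(x)^T f(x+) and b = f(x)^T f(x-)."""
    a = dot(fx, fx_pos)
    b = dot(fx, fx_neg)
    m = max(a, b)  # subtract the max before exponentiating, for numerical stability
    log_denominator = m + math.log(math.exp(a - m) + math.exp(b - m))
    return log_denominator - a

def empirical_unsup_loss(triples):
    """Sample average of the per-triple loss, standing in for the expectation in (1)."""
    return sum(contrastive_loss(*t) for t in triples) / len(triples)

# Hypothetical embeddings (f(x), f(x+), f(x-)); not taken from any real model.
triples = [
    ([1.0, 0.0], [0.9, 0.1], [-1.0, 0.2]),
    ([0.0, 1.0], [0.1, 0.8], [0.7, -0.5]),
]
print(empirical_unsup_loss(triples))
```

When the positive and the negative score identically, the per-triple loss equals $\log 2$; it falls below that as $f(x)^Tf(x^+)$ pulls ahead of $f(x)^Tf(x^-)$.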
Need for a new framework
The standard framework for machine learning involves minimizing some loss function, and learning is said to succeed (or generalize) if the loss is roughly the same on the average training data point and the average test data point. In contrastive learning, however, the objective used at test time is very different from the training objective: generalization error is not the right way to think about this.
Main Hurdle for Theory: We have to show that doing well on task A (minimizing the word2vec-like objective) allows the representation to do well on task B (i.e., classification tasks revealed later).
Earlier methods along such lines include kernel learning and semi-supervised learning, but training there typically requires at least a few labeled examples from the classification tasks of future interest. Bayesian approaches using generative models are also well-established in simpler settings, but have proved difficult for complicated data such as images and text. Furthermore, the simple word2vec-like learners described above do not appear to operate like Bayesian optimizers in any obvious way, and also work for very different data types.
We tackle this problem by proposing a framework that formalizes the notion of semantic similarity that is implicitly used by these algorithms and use the framework to show why contrastive learning gives good representations, while defining what good representations mean in this context.
Our framework
Clearly, the implicit/heuristic notion of similarity used in contrastive learning is connected to the downstream tasks in some way — e.g., similarity carries a strong hint that on average the “similar pairs” tend to be assigned the same labels in many downstream tasks (though there is no hard guarantee per se). We present a simple and minimalistic framework to formalize such a notion of similarity. For purposes of exposition we’ll refer to data points as “images”.
Semantic similarity
We assume nature has many classes of images, and has a measure $\rho$ on a set of classes $\mathcal{C}$, so that if asked to pick a class it selects $c$ with probability $\rho(c)$. Each class $c$ also has an associated distribution $D_c$ on images, i.e., if nature is asked to furnish examples of class $c$ (e.g., the class “dogs”) then it picks image $x$ with probability $D_c(x)$. Note that classes can have arbitrary overlap, including no overlap. To formalize a notion of semantic similarity we assume that when asked to provide “similar” images, nature picks a class $c^+$ from $\mathcal{C}$ using measure $\rho$ and then picks two i.i.d. samples $x, x^{+}$ from the distribution $D_{c^+}$. The dissimilar example $x^{-}$ is picked by selecting another class $c^-$ from measure $\rho$, independently of $c^+$, and picking a random sample $x^{-}$ from $D_{c^-}$.
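The sampling process just described can be sketched in a few lines. Everything concrete here (the three class names, the values of $\rho$, and the finite lists standing in for each $D_c$) is a made-up toy instance, not data from the paper.

```python
import random

def sample_triple(rho, class_samplers, rng):
    """Draw (x, x+, x-) as in the framework: c+ and c- are independent draws
    from rho, x and x+ are i.i.d. from D_{c+}, and x- is drawn from D_{c-}."""
    classes = list(rho)
    weights = [rho[c] for c in classes]
    # choices() samples with replacement, so c+ and c- may coincide.
    c_pos, c_neg = rng.choices(classes, weights=weights, k=2)
    x = class_samplers[c_pos](rng)
    x_pos = class_samplers[c_pos](rng)
    x_neg = class_samplers[c_neg](rng)
    # The learner only ever sees the triple; the latent classes are returned
    # here purely so the sketch can be inspected.
    return (x, x_pos, x_neg), (c_pos, c_neg)

# Toy instance: each D_c is uniform over a small list of placeholder items.
toy_data = {
    "dogs": ["dog_0", "dog_1", "dog_2"],
    "cats": ["cat_0", "cat_1", "cat_2"],
    "cars": ["car_0", "car_1", "car_2"],
}
rho = {"dogs": 0.5, "cats": 0.3, "cars": 0.2}
class_samplers = {
    c: (lambda rng, items=items: rng.choice(items)) for c, items in toy_data.items()
}

rng = random.Random(0)
(x, x_pos, x_neg), (c_pos, c_neg) = sample_triple(rho, class_samplers, rng)
print(x, x_pos, x_neg)
```

Because $c^-$ is drawn independently of $c^+$, the "negative" sample occasionally comes from the same class as the positive pair; the framework accounts for such collisions rather than ruling them out.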
The training objective for learning the representation is exactly the QT objective from earlier, but now inherits the following interpretation from the framework: \(L_{un}(f)=\mathop{\mathbb{E}}\limits_{c^+,c^-\sim\rho}\ \mathop{\mathbb{E}}\limits_{x,x^+\sim D_{c^+}}\ \mathop{\mathbb{E}}\limits_{x^-\sim D_{c^-}}\left[\log\left(1+e^{f(x)^Tf(x^-)-f(x)^Tf(x^+)}\right)\right]\), which is minimized over $f\in\mathcal{F}$.
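That the expression inside this expectation is the same as the one in (1) follows from a one-line identity. Writing $a=f(x)^Tf(x^+)$ and $b=f(x)^Tf(x^-)$ (shorthand introduced only for this remark):

\[-\log\left(\frac{e^{a}}{e^{a}+e^{b}}\right)=\log\left(\frac{e^{a}+e^{b}}{e^{a}}\right)=\log\left(1+e^{b-a}\right)\]

so the only change from (1) is that the expectation over $(x,x^+,x^-)$ is now spelled out in terms of the latent classes $c^+, c^-$.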
Note that the function class $\mathcal{F}$ is an arbitrary deep net architecture mapping images to embeddings (neural net sans the final layer), and one would learn $f$ via gradient descent/back-propagation as usual. Of course, no theory currently exists for explaining when optimization succeeds for complicated deep nets, so our framework will simply assume that gradient descent has already resulted in some representation $f$ that achieves low loss, and studies how well this does in downstream classification tasks.
Testing representations
What defines a good representation? We assume that the quality of the representation is tested by using it to solve a binary (i.e., two-way) classification task using a linear classifier. (The paper also studies extensions to $k$-way classification in the downstream task.) How is this binary classification task selected? Nature picks two classes $c_1, c_2$ randomly according to measure $\rho$ and picks data points for each class according to the associated probability distributions $D_{c_1}$ and $D_{c_2}$. The representation is then used to solve this binary task via logistic regression: namely, find two vectors $w_1, w_2$ so as to minimize the following loss \(L_{sup}(f, (c_1,c_2))=\inf\limits_{w_1,w_2}\left[\frac{1}{2}\mathop{\mathbb{E}}\limits_{x\sim D_{c_1}}\log(1+e^{f(x)^Tw_2-f(x)^Tw_1})+\frac{1}{2}\mathop{\mathbb{E}}\limits_{x\sim D_{c_2}}\log(1+e^{f(x)^Tw_1-f(x)^Tw_2})\right]\)
The quality of the representation is estimated as the average loss over nature’s choices of binary classification tasks: \(L_{sup}(f)=\mathop{\mathbb{E}}\limits_{c_1,c_2\sim\rho}\ \left[L_{sup}(f, (c_1,c_2))\mid c_1\neq c_2\right]\)
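As an illustration of how one binary task is scored, the sketch below evaluates the bracketed loss for fixed $w_1, w_2$ on sample representations and approximates the infimum with plain gradient descent. The representations, step size, and iteration count are arbitrary choices for this toy example; any logistic regression solver would do.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softplus(z):
    """Numerically stable log(1 + e^z)."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def sigmoid(z):
    """Numerically stable logistic function, the derivative of softplus."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def binary_sup_loss(reps_1, reps_2, w1, w2):
    """Loss inside the inf of L_sup(f, (c1, c2)) for fixed w1, w2,
    with the two expectations replaced by sample averages."""
    loss_1 = sum(softplus(dot(r, w2) - dot(r, w1)) for r in reps_1) / len(reps_1)
    loss_2 = sum(softplus(dot(r, w1) - dot(r, w2)) for r in reps_2) / len(reps_2)
    return 0.5 * loss_1 + 0.5 * loss_2

def fit_linear_classifiers(reps_1, reps_2, steps=200, lr=0.5):
    """Gradient descent on the loss. Only w1 - w2 matters, so w2 is kept at zero."""
    d = len(reps_1[0])
    w1, w2 = [0.0] * d, [0.0] * d
    for _ in range(steps):
        grad = [0.0] * d
        for r in reps_1:  # gradient of softplus(-r.w1) with respect to w1
            g = -sigmoid(-dot(r, w1)) * 0.5 / len(reps_1)
            grad = [gi + g * ri for gi, ri in zip(grad, r)]
        for r in reps_2:  # gradient of softplus(r.w1) with respect to w1
            g = sigmoid(dot(r, w1)) * 0.5 / len(reps_2)
            grad = [gi + g * ri for gi, ri in zip(grad, r)]
        w1 = [wi - lr * gi for wi, gi in zip(w1, grad)]
    return w1, w2

# Hypothetical representations f(x) for samples from two classes c1 and c2.
reps_1 = [[1.0, 0.2], [0.8, -0.1], [1.2, 0.1]]
reps_2 = [[-0.9, 0.3], [-1.1, 0.0], [-0.7, -0.2]]

zero = [0.0, 0.0]
w1, w2 = fit_linear_classifiers(reps_1, reps_2)
print(binary_sup_loss(reps_1, reps_2, zero, zero))  # loss of the trivial classifier
print(binary_sup_loss(reps_1, reps_2, w1, w2))      # loss after fitting
```

The trivial choice $w_1=w_2=0$ always gives loss $\log 2$, so a useful representation is one for which the fitted loss lands well below that on a typical task.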
It is important to note that the latent classes present in the unlabeled data are the same classes present in the classification tasks. This allows us to formalize a sense of ‘semantic similarity’ as alluded to above: the classes from which data points appear together more frequently are the classes that make up relevant classification tasks. Note that if the number of classes is large, then typically the data used in unsupervised training may involve no samples from the classes used at test time. Indeed, we are hoping to show that the learned representations are useful for classification on potentially unseen classes.
Provable guarantees for unsupervised learning
What would be a dream result for theory? Suppose we fix a class of representation functions ${\mathcal F}$, say those computable by a ResNet 50 architecture with some choices of layer sizes etc.
Dream Theorem: Minimizing the unsupervised loss (using a modest amount of unlabeled data) yields a representation function $f \in {\mathcal F}$ that is competitive with the best representation from ${\mathcal F}$ on downstream classification tasks, even with very few labeled examples per task.
While the number of unlabeled data pairs needed to learn an approximate minimizer can be controlled using Rademacher complexity arguments (see paper), we show that the dream theorem is impossible as phrased: we can exhibit a simple class ${\mathcal F}$ where the contrastive objective does not yield representations even remotely competitive with the best in the class. This should not be surprising and only suggests that further progress towards such a dream result would require making more assumptions than the above minimalistic ones.
Instead, our paper makes progress by showing that under the above framework, if the unsupervised loss happens to be small at the end of contrastive learning then the resulting representations perform well on downstream classification.
Simple Lemma: The average classification loss on downstream binary tasks is upper bounded by the unsupervised loss. \(L_{sup}(f)\le \alpha L_{un}(f),~\forall f\in\mathcal{F}\ \ \ \ \ \ (2)\) where $\alpha$ depends on $\rho$. ($\alpha\rightarrow 1$ when $|\mathcal{C}|\rightarrow\infty$, for uniform $\rho$)
This says that the unsupervised loss function can be treated as a surrogate for the performance on downstream supervised tasks solved using linear classification, so minimizing it makes sense. Furthermore, just a few labeled examples are needed to learn the linear classifiers in future downstream tasks. Thus our minimalistic framework lets us show guarantees for contrastive learning and also highlights the labeled sample complexity benefits provided by it. For details as well as a more fine-grained analysis, see the paper.
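Inequality (2) can be sanity-checked numerically on a tiny instance of the framework. The sketch below makes several assumptions that are not spelled out in this post: $\rho$ is uniform over three classes, each $D_c$ is uniform over two hard-coded representation vectors, the supervised loss is evaluated with the class means as $w_1, w_2$ (which can only overestimate the infimum), and $\alpha$ is taken to be $1/(1-\tau)$ with $\tau$ the probability that two independent draws from $\rho$ coincide. That choice of $\alpha$ matches the stated limit $\alpha\rightarrow 1$; the paper has the exact statement.

```python
import math
from itertools import product

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softplus(z):
    """Numerically stable log(1 + e^z)."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def avg(values):
    values = list(values)
    return sum(values) / len(values)

def mean_vector(vectors):
    return [avg(column) for column in zip(*vectors)]

def unsup_loss(class_reps):
    """Exact L_un when rho is uniform and each D_c is uniform over its list."""
    classes = list(class_reps)
    return avg(
        avg(
            softplus(dot(fx, fx_neg) - dot(fx, fx_pos))
            for fx, fx_pos in product(class_reps[c_pos], repeat=2)  # x, x+ i.i.d.
            for fx_neg in class_reps[c_neg]
        )
        for c_pos, c_neg in product(classes, repeat=2)  # c+, c- independent
    )

def mean_classifier_sup_loss(class_reps):
    """Average binary loss over tasks with c1 != c2, using class means as w1, w2."""
    classes = list(class_reps)
    mu = {c: mean_vector(class_reps[c]) for c in classes}

    def task_loss(c1, c2):
        l1 = avg(softplus(dot(r, mu[c2]) - dot(r, mu[c1])) for r in class_reps[c1])
        l2 = avg(softplus(dot(r, mu[c1]) - dot(r, mu[c2])) for r in class_reps[c2])
        return 0.5 * l1 + 0.5 * l2

    return avg(task_loss(c1, c2) for c1, c2 in product(classes, repeat=2) if c1 != c2)

# Hypothetical representations f(x) for every point of three latent classes.
class_reps = {
    "c0": [[1.0, 0.1], [0.8, -0.2]],
    "c1": [[-0.1, 1.0], [0.2, 0.9]],
    "c2": [[-0.9, -0.8], [-1.1, -0.6]],
}
tau = 1.0 / len(class_reps)  # collision probability for uniform rho
alpha = 1.0 / (1.0 - tau)    # assumed form of alpha, see the lead-in
l_un = unsup_loss(class_reps)
l_sup_mean = mean_classifier_sup_loss(class_reps)
print(l_sup_mean, alpha * l_un)
```

The first printed number should not exceed the second. In this toy setting that follows from convexity of $z\mapsto\log(1+e^{z})$: averaging $f(x^+)$ and $f(x^-)$ inside the loss can only decrease it, which is why class means appear naturally as the downstream classifiers.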
Extensions of the theoretical analysis
This conceptual framework not only allows us to reason about empirically successful variants of (1), but also leads to the design of new, theoretically grounded unsupervised objective functions. Here we give a high level view; details are in our paper.
A priori, one might imagine that the log and exponentials in (1) have some information-theoretic interpretation; here we relate the functional form to the fact that logistic regression is going to be used in the downstream classification tasks. Analogously, if the classification is done via hinge loss, then (2) is true for a different unsupervised loss that uses a hinge-like loss instead. This objective, for instance, was used to learn image representations from videos by Wang and Gupta. Also, usually in practice $k>1$ negative samples are contrasted with each positive sample $(x,x^+)$ and the unsupervised objective looks like the $(k+1)$-class cross-entropy loss, with the positive competing against the $k$ negatives. We prove a statement similar to (2) for this setting, where the supervised loss now is the average $(k+1)$-way classification loss.
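Both variants mentioned here are small modifications of the per-triple loss in (1). The sketch below is illustrative only: the function names are invented, and the margin of 1 in the hinge version is an assumed default rather than something fixed by this post.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def k_negative_loss(fx, fx_pos, fx_negs):
    """Cross-entropy with the positive as the correct class among the
    candidates {x+, x-_1, ..., x-_k}: -log(e^a / (e^a + sum_i e^{b_i}))."""
    a = dot(fx, fx_pos)
    logits = [a] + [dot(fx, fx_neg) for fx_neg in fx_negs]
    m = max(logits)  # log-sum-exp with the max subtracted, for stability
    log_partition = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_partition - a

def hinge_contrastive_loss(fx, fx_pos, fx_neg, margin=1.0):
    """Hinge analogue of (1): zero once the positive beats the negative by the margin."""
    return max(0.0, margin + dot(fx, fx_neg) - dot(fx, fx_pos))

# Hypothetical embeddings, for illustration only.
fx = [1.0, 0.5]
fx_pos = [0.9, 0.4]
fx_negs = [[-0.5, 0.2], [0.1, -0.8], [-1.0, -1.0]]

print(k_negative_loss(fx, fx_pos, fx_negs))
print(hinge_contrastive_loss(fx, fx_pos, fx_negs[0]))
```

With a single negative, `k_negative_loss` reduces to the loss in (1). With $k$ negatives that all score the same as the positive it equals $\log(k+1)$, the loss of a uniform guess among $k+1$ classes.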
Finally, the framework provides guidelines for designing new unsupervised objectives when blocks of similar data are available (e.g., sentences in a paragraph). Replacing $f(x^+)$ and $f(x^-)$ in (1) with the average of the representations from the positive and the negative block respectively, we get a new objective which comes with stronger guarantees and better performance in practice. We experimentally verify the effectiveness of this variant in our paper.
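A minimal sketch of the block variant follows, assuming blocks are given as lists of already-computed representations; the vectors and function names are invented for illustration. It also compares the block loss with the average of the ordinary pairwise losses over the same points.

```python
import math
from itertools import product

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softplus(z):
    """Numerically stable log(1 + e^z)."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def mean_vector(vectors):
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]

def block_loss(fx, pos_block, neg_block):
    """Loss from (1) with f(x+) and f(x-) replaced by block averages."""
    mu_pos = mean_vector(pos_block)
    mu_neg = mean_vector(neg_block)
    return softplus(dot(fx, mu_neg) - dot(fx, mu_pos))

def average_pairwise_loss(fx, pos_block, neg_block):
    """Average of the original per-triple loss over all (x+, x-) pairs in the blocks."""
    losses = [
        softplus(dot(fx, n) - dot(fx, p)) for p, n in product(pos_block, neg_block)
    ]
    return sum(losses) / len(losses)

# Hypothetical representations: one anchor, a block of similar points, a block of negatives.
fx = [1.0, 0.2]
pos_block = [[1.0, 0.0], [0.8, 0.3], [0.9, -0.1]]
neg_block = [[-1.0, 0.2], [0.1, -0.9], [-0.4, 0.5]]

print(block_loss(fx, pos_block, neg_block))
print(average_pairwise_loss(fx, pos_block, neg_block))
```

Since $z\mapsto\log(1+e^{z})$ is convex, Jensen's inequality guarantees that the block loss is never larger than the averaged pairwise loss on the same points, which gives some intuition for why averaging within blocks can help.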
Experiments
We report some controlled experiments to verify the theory. Lacking a canonical multiclass problem for text, we constructed a new 3029-class labeled dataset where a class is one of 3029 articles from Wikipedia, and a data point is one of $200$ sentences from its article. Representations are tested on a random binary classification task that involves two articles, where the label of a data point is which of the two articles it belongs to. (A 10-way classification task is similarly defined.) Data points for the test tasks are held out while training representations. The class of sentence representations ${\mathcal F}$ is a simple multilayer architecture based on the Gated Recurrent Unit (GRU).
The supervised method for learning representations trains a multiclass classifier on the 3029-way task, and the representation is taken from the layer before the final softmax output. This is the gold standard referred to in the discussion above.
The unsupervised method is fed pairs of similar data points generated according to our theory: similar data points are just pairs of sentences sampled from the same article. Representations are learnt by minimizing the above unsupervised loss objectives.
The highlighted parts in the table show that the unsupervised representations compete well with the supervised representations on the average $k$-way classification task ($k=2, 10$).
Additionally, even though this is not covered by our theory, the representation also performs respectably on the full multiclass problem. We find some support for a suggestion of our theory that the means (centroids) of the unsupervised representations in each class should be good classifiers for average $k$-way supervised tasks. We find this to be true for unsupervised representations, and surprisingly for supervised representations as well.
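The centroid classifier mentioned above is simple enough to sketch directly: each class is represented by the mean of its representations, and a new point is assigned to the class whose mean has the largest inner product with it. The labels and vectors below are invented for illustration and are unrelated to the Wikipedia experiments.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def mean_vector(vectors):
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]

def class_means(labeled_reps):
    """Map each label to the centroid of its representations."""
    return {label: mean_vector(reps) for label, reps in labeled_reps.items()}

def predict(rep, means):
    """Assign rep to the class c maximizing f(x)^T mu_c."""
    return max(means, key=lambda label: dot(rep, means[label]))

# Hypothetical labeled representations for a 3-way task.
labeled_reps = {
    "a": [[1.0, 0.0], [0.8, 0.2]],
    "b": [[0.0, 1.0], [0.2, 0.8]],
    "c": [[-1.0, -1.0], [-0.8, -1.2]],
}
means = class_means(labeled_reps)
print(predict([0.7, 0.1], means))
print(predict([-0.5, -0.5], means))
```

No optimization is needed to build this classifier, only one averaging pass over the labeled examples, which is part of what makes it attractive when labels are scarce.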
The paper also has other experiments studying the effect of number of negative samples and larger blocks of similar data points, including experiments on the CIFAR-100 image dataset.
Conclusions
While contrastive learning is a well-known intuitive algorithm, its practical success has been a mystery for theory. Our conceptual framework lets us formally show guarantees for representations learnt using such algorithms. While shedding light on such algorithms, the framework also lets us come up with and analyze variants of them. It also provides insights into what guarantees are provable and shapes the search for new assumptions that would allow stronger guarantees. While this is a first cut, possible extensions include imposing a metric structure among the latent classes. Connections to meta-learning and transfer learning may also arise. We hope that this framework influences and guides practical implementations in the future.