(Research Idea) Discrete Diffusion for Inverse Protein Folding
Introduction
Deep learning made a huge impression on the scientific community when DeepMind released AlphaFold, a deep network that predicts a protein's structure from its sequence. AlphaFold outperformed all other protein folding methods and stands as the baseline to beat. The converse problem, inverse protein folding, is to predict a protein's sequence from its structure; it is arguably even more challenging because many different sequences can fold to a particular structure. Both physics-based methods and deep learning methods have been proposed for inverse protein folding. One particularly promising and quite novel approach is the use of denoising diffusion models.
Background: (denoising) diffusion models
For a more extensive treatment, Lilian Weng has a great introductory post on diffusion models (Weng 2021). Calvin Luo also provides great intuition, and some of his insights are reflected below (Luo 2022).
In short, a diffusion model (more precisely, a denoising diffusion probabilistic model) is a parameterized Markov chain that learns to generate samples that match the input data. The diffusion model has a forward process which takes a true data sample $x_0$ and progressively corrupts it with noise over many timesteps; notationally, the data undergoes many forward transitions $q(x_t \mid x_{t-1})$, which for continuous data is commonly an isotropic Gaussian $q(x_t \mid x_{t-1}) = \mathcal{N}(x_t;\, \sqrt{1-\beta_t}\, x_{t-1},\, \beta_t \mathbf{I})$ (here $\beta_t$ defines the variance and can either be fixed as a hyper-parameter or learned by reparameterization). After many forward transitions, the original data ends up corrupted to pure noise.
The diffusion model learns a reverse process to undo this forward noising; in this way, it can take a sample drawn from noise and denoise it into a sample that lies in the original data distribution $q(x_0)$. For example, if the forward transition is Gaussian, our diffusion model would learn the mean and covariance of a reverse Gaussian transition, i.e. $p_\theta(x_{t-1} \mid x_t) = \mathcal{N}(x_{t-1};\, \mu_\theta(x_t, t),\, \Sigma_\theta(x_t, t))$. Once we learn the reverse process, we can take randomly sampled noise $x_T \sim \mathcal{N}(0, \mathbf{I})$ and apply the learned reverse transitions many times to reach an $x_0$ that hopefully lies in the true data distribution $q(x_0)$.
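To make the two processes concrete, here is a minimal sketch in PyTorch of the closed-form forward noising step and the reverse sampling loop. The linear `betas` schedule, the number of timesteps `T`, and the placeholder `model` (returning a mean and standard deviation) are illustrative assumptions, not any specific published implementation.

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # variance schedule beta_t (hyper-parameters)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)    # cumulative product \bar{alpha}_t

def forward_noise(x0, t):
    """Sample x_t ~ q(x_t | x_0) using the closed form of the forward process."""
    eps = torch.randn_like(x0)
    x_t = alpha_bars[t].sqrt() * x0 + (1 - alpha_bars[t]).sqrt() * eps
    return x_t, eps

@torch.no_grad()
def reverse_sample(model, shape):
    """Start from pure noise x_T and apply the learned reverse transitions."""
    x = torch.randn(shape)
    for t in reversed(range(T)):
        mu, sigma = model(x, t)              # learned mean/std of p_theta(x_{t-1} | x_t)
        x = mu + sigma * torch.randn_like(x) if t > 0 else mu
    return x
```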
How do we learn the reverse process? Ultimately the distribution that results from the reverse process should match the distribution of the real data. An intuitive objective is to maximize the likelihood of the real data under the learned reverse distribution. Because maximizing the likelihood directly is intractable, we maximize a variational lower bound of the log likelihood. We can turn this into a cost function by taking the negative and instead minimizing the negative of the lower bound, which spelled out is:

$$\mathbb{E}_q\Big[\, \underbrace{D_{\mathrm{KL}}\big(q(x_T \mid x_0)\,\|\,p(x_T)\big)}_{\text{prior matching}} \;+\; \sum_{t=2}^{T} \underbrace{D_{\mathrm{KL}}\big(q(x_{t-1} \mid x_t, x_0)\,\|\,p_\theta(x_{t-1} \mid x_t)\big)}_{\text{denoising matching}} \;-\; \underbrace{\log p_\theta(x_0 \mid x_1)}_{\text{reconstruction}} \,\Big]$$
With enough timesteps, the forward process corrupts the data such that $x_T$ is essentially noise; thus the first term, the KL divergence between $q(x_T \mid x_0)$ and the assumed noise prior $p(x_T)$, is basically zero. The last term is a reconstruction loss where we try to maximize the likelihood of predicting the original data $x_0$ given the first latent $x_1$. With the middle term, we see that at each timestep we try to match our predicted transition $p_\theta(x_{t-1} \mid x_t)$ with the actual reverse transition $q(x_{t-1} \mid x_t, x_0)$. (Note: we also condition on $x_0$ to make this calculation tractable, as explained in (Sohl-Dickstein 2015).)
To maximize this lower bound (or minimize the cost function), we first parametrize the ground-truth reverse transition $q(x_{t-1} \mid x_t, x_0)$ so that we can properly construct our estimated reverse transition $p_\theta(x_{t-1} \mid x_t)$. In the Gaussian setting defined in (Ho et al. 2020), $q(x_{t-1} \mid x_t, x_0)$ ends up being an isotropic Gaussian $\mathcal{N}(x_{t-1};\, \mu_q(x_t, x_0),\, \sigma_q^2(t)\mathbf{I})$. Based on this parameterization of the reverse transition, we learn an isotropic Gaussian $p_\theta(x_{t-1} \mid x_t) = \mathcal{N}(x_{t-1};\, \mu_\theta(x_t, t),\, \sigma_q^2(t)\mathbf{I})$ to match this transition and minimize the KL divergence of the two distributions. It turns out that our objective reduces to

$$\frac{1}{2\sigma_q^2(t)}\,\big\|\mu_\theta(x_t, t) - \mu_q(x_t, x_0)\big\|_2^2,$$
i.e. we try to predict the reverse process mean $\mu_q(x_t, x_0)$ at each timestep $t$.
(Ho et al. 2020) show that it is effective to train a model (i.e. a neural network) to learn the added noise rather than the mean and covariance of the diffused data. Thus they use the simplified objective function $L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon}\big[\|\epsilon - \epsilon_\theta(x_t, t)\|^2\big]$.
Why does this simplified objective make sense?
In the setting of (Ho et al. 2020), the noise $\epsilon \sim \mathcal{N}(0, \mathbf{I})$ is drawn from a Gaussian that is independent of both the timestep $t$ and the data $x_0$. So we might wonder why we are trying to predict the noise as a function of $x_t$ and $t$. While this noise prediction objective may seem counter-intuitive, it in fact derives from the objective of predicting the reverse process mean $\mu_q(x_t, x_0)$ at some timestep (which is a function of $x_t$ and $x_0$). With reparameterization and some simplifications (spelled out in (Ho et al. 2020)), the mean prediction objective becomes the noise prediction objective.
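As a sketch of how this simplified objective is used in training (reusing `T` and `alpha_bars` from the sketch above; `eps_model` is a hypothetical noise-prediction network):

```python
import torch

def training_loss(eps_model, x0):
    """Simplified DDPM-style loss: noise x0 to a random timestep and regress
    the network output onto the noise that was added."""
    t = torch.randint(0, T, (x0.shape[0],))
    eps = torch.randn_like(x0)
    a_bar = alpha_bars[t].view(-1, *([1] * (x0.dim() - 1)))   # broadcast over data dims
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps
    return ((eps - eps_model(x_t, t)) ** 2).mean()
```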
Inverse protein folding
The problem of inverse protein folding is to predict a protein's sequence from its structure. Intuitively, our prediction model should have the property of roto-translational invariance. Given some protein structure, even if we apply a bunch of translations or rotations to it, it should still map to the same sequence that defines it.
How do we achieve roto-translation invariance?
One strategy is to encode the protein structure into a roto-translation invariant representation. If the structure is represented just by the 3D coordinates of each residue, it is not rotation or translation invariant; but if we instead use geometric variables such as distances or angles, we get an invariant representation of the structure. For example, (Wu et al. 2022) use "inter-residue angles", i.e. angles that represent the position of the next residue relative to the current residue, to get an invariant representation. Alternatively, instead of manually crafting an invariant representation, we can use a learned network to extract an invariant latent representation. For example, (Hsu et al. 2022) use a geometric vector perceptron graph neural network (GVP-GNN) encoder-decoder to get a geometrically invariant latent vector from the protein structure.
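As a tiny illustration of a hand-crafted invariant representation (not the specific features used in the cited papers), pairwise distances between residue coordinates are unchanged by any rotation or translation of the whole structure:

```python
import numpy as np

def invariant_features(coords):
    """coords: (N, 3) array of residue coordinates. The (N, N) pairwise
    distance matrix is invariant to rotating or translating the structure."""
    diff = coords[:, None, :] - coords[None, :, :]
    return np.linalg.norm(diff, axis=-1)
```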
Diffusion models for inverse protein folding
Learning to generate sequences conditioned on structure data
We know that a diffusion model can learn to generate realistic samples that approximately lie in the real data distribution. Crucially, our diffusion model (which is just a neural network that predicts the reverse process) can be conditioned on side information, in our case the protein structure. Then we can simply pass a protein sequence sampled from noise through the reverse diffusion process, conditioning on the protein structure, to generate a realistic protein sequence corresponding to that structure.
To unpack this, say we have some real protein sequence and structure data $(s, \mathcal{X})$, where $s$ is the protein sequence and $\mathcal{X}$ is the structure. A conditional diffusion model that diffuses the protein sequence minimizes the objective function (here we use the simplified noise objective)

$$\mathbb{E}_{t, s_0, \epsilon}\big[\|\epsilon - \epsilon_\theta(s_t, \mathcal{X}, t)\|^2\big],$$
where $\mathcal{X}$ is the real protein structure and $\epsilon$ is the noise on the sequence data. We see that our diffusion model is learning the noise conditional on the current noisy sequence $s_t$ and the actual protein structure $\mathcal{X}$ (and the timestep $t$). To make this explicit, under the squared objective we learn the conditional mean $\mathbb{E}[\epsilon \mid s_t, \mathcal{X}, t]$.
How can we implement this in practice?
I was initially confused about how to learn this conditional distribution with multiple inputs in practice. But it is no different than the toy example of predicting house prices conditional on room sizes and house characteristics; there, we could concatenate the room sizes with the house characteristics to create our input and simply pass that to the model. Similarly, we could concatenate the two inputs $s_t$ and $\mathcal{X}$: for example, with a transformer we would pass in a sequence whose first chunk is the sequence information and whose second chunk is the structure information, and then predict the noise with a regression head, as sketched below.
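A hypothetical sketch of such a conditional noise-prediction network in PyTorch (all dimensions, the per-residue structure features, and the module layout are illustrative assumptions):

```python
import torch
import torch.nn as nn

class ConditionalEpsModel(nn.Module):
    def __init__(self, d_model=128, n_residues=20, struct_dim=16, max_t=1000):
        super().__init__()
        self.seq_embed = nn.Linear(n_residues, d_model)      # noisy (relaxed one-hot) sequence
        self.struct_embed = nn.Linear(struct_dim, d_model)   # per-residue structure features
        self.time_embed = nn.Embedding(max_t, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=4)
        self.head = nn.Linear(d_model, n_residues)           # regression head for the noise

    def forward(self, s_t, structure, t):
        # concatenate sequence tokens and structure tokens along the length dimension
        seq_tok = self.seq_embed(s_t) + self.time_embed(t)[:, None, :]
        struct_tok = self.struct_embed(structure)
        h = self.encoder(torch.cat([seq_tok, struct_tok], dim=1))
        return self.head(h[:, : s_t.shape[1]])               # read noise off the sequence positions
```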
Conditioning on information has also been shown to lead to better diffusion performance; this is the central idea of classifier-free guidance, introduced by (Ho et al. 2022). The idea is that the diffusion model uses a linear combination of a conditional and an unconditional noise estimate

$$\tilde{\epsilon}_\theta(x_t, c) = (1 + w)\,\epsilon_\theta(x_t, c) - w\,\epsilon_\theta(x_t),$$
where $c$ is the conditioning information and $w$ toggles how much to weight the conditional estimate.
Interestingly, (Ho et al. 2022) learn both the conditional and unconditional noise estimates using a single neural network for simplicity. To learn the unconditional noise estimate, they simply input a null token as the conditional input, so the network is effectively conditioned only on the noisy input $x_t$.
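A minimal sketch of the guidance combination at sampling time (assuming a noise network `eps_model(x_t, t, cond)` and some `null_cond` placeholder):

```python
def guided_eps(eps_model, x_t, t, cond, null_cond, w):
    """Classifier-free guidance: weighted combination of the conditional and
    unconditional noise estimates from a single network."""
    eps_cond = eps_model(x_t, t, cond)          # conditioned on the structure information c
    eps_uncond = eps_model(x_t, t, null_cond)   # "unconditional": null token as the condition
    return (1 + w) * eps_cond - w * eps_uncond
```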
Learning to generate sequence and structures jointly
Instead of learning to generate realistic protein sequences conditioned on structure information, we can instead learn to generate realistic protein sequence and structure pairs. Now, instead of just reverse diffusing (denoising) a noisy sampled sequence, we reverse diffuse a noisy protein sequence and structure pair to get a realistic sequence and structure pair. Under our objective we learn the noise $[\epsilon^s, \epsilon^{\mathcal{X}}]$, where we denote $\epsilon^s$ as the noise on the sequence data and $\epsilon^{\mathcal{X}}$ as the noise on the structure data (Bao et al. 2023). In contrast to before, instead of conditioning on the initial true protein structure, we are conditioning on the noisy sampled structure $\mathcal{X}_t$ corresponding to that timestep.
Speeding up with a latent vector (latent diffusion)
One problem with diffusion models is that the forward and reverse processes can be very expensive in time and compute. Instead of diffusing directly on the input data, the latent diffusion approach compresses the input into a smaller latent representation using an autoencoder (Rombach et al. 2022). Diffusion is run on the latent representation (i.e. we forward diffuse the latents, learn the reverse diffusion process on the latents, then sample a random noisy latent and reverse diffuse it into a realistic latent that lies in the latent data distribution), and the generated latent is then decoded with the autoencoder to produce a realistic sample in the original input space.
Diffusion with non-continuous data (e.g. protein sequences)
In learning to generate a protein sequence conditioned on the structure, or in learning to generate a protein sequence and structure pair, diffusion needs to forward diffuse (noise) a sequence. But a protein sequence is a sequence of discrete variables, where each variable takes one of twenty residue categories. We can't simply add Gaussian noise to discrete variables. One solution is to embed the categorical protein residues as continuous variables, but this can be an unnatural approach. Rather, we can tailor our transitions to discrete variables. This is the idea of Discrete Denoising Diffusion Probabilistic Models, introduced by (Austin et al. 2021).
One example of a transition function for discrete spaces is the uniform transition: it stays in the same state with a certain probability and distributes the remaining probability mass uniformly over all possible states. Another example is a discretized Gaussian, where the Gaussian is truncated and discretized so that it mimics the continuous Gaussian transition. To define the transition function more generally, we can consider a transition matrix $Q_t$ whose entry $[Q_t]_{ij}$ defines the probability of transitioning from state value $i$ to state value $j$. Notationally, our categorical variable $x_t$ is a one-hot row vector (i.e. 1 for the corresponding protein residue, 0 elsewhere) and we define our transition function as the categorical distribution $q(x_t \mid x_{t-1}) = \mathrm{Cat}(x_t;\, p = x_{t-1} Q_t)$, where $Q_t$ is our transition matrix. For our Markov chain to be well-defined, we need to choose $Q_t$ such that for large $t$ the cumulative product $\bar{Q}_t = Q_1 Q_2 \cdots Q_t$ converges to a known stationary distribution.
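A small numpy sketch of a uniform transition matrix and one categorical forward step (the 20 residue categories and the `beta` value are illustrative):

```python
import numpy as np

K = 20  # number of protein residue categories

def uniform_Q(beta):
    """Uniform transition: keep the current state with prob 1 - beta, otherwise
    move to a uniformly random category. The matrix is doubly stochastic."""
    return (1 - beta) * np.eye(K) + beta * np.ones((K, K)) / K

def forward_step(x_onehot, Q, rng):
    """One forward transition q(x_t | x_{t-1}) = Cat(x_t; p = x_{t-1} Q)."""
    probs = x_onehot @ Q                 # the row of Q selected by the one-hot vector
    return rng.multinomial(1, probs)     # sample the next one-hot state

rng = np.random.default_rng(0)
x0 = np.eye(K)[3]                        # residue category 3 as a one-hot row vector
x1 = forward_step(x0, uniform_Q(0.1), rng)
```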
As far as I know, there is no general heuristic for which transition function will let the diffusion model generate better samples. In the case of generating images, (Austin et al. 2021) show that the discretized Gaussian performs better than both the uniform transition and an absorbing transition (defined below).
Discrete diffusion objective function
In the case of discrete diffusion, it is clearer to go back to the fully spelled-out variational log likelihood objective.
(Austin et al. 2021) use a loss function that is a sum of the variational likelihood objective and an auxiliary cross-entropy term, or in full

$$L_\lambda = L_{\mathrm{vb}} + \lambda\, \mathbb{E}_{q(x_0)}\,\mathbb{E}_{q(x_t \mid x_0)}\big[-\log \tilde{p}_\theta(x_0 \mid x_t)\big].$$
To deal with the main term, the KL divergence between the true reverse process $q(x_{t-1} \mid x_t, x_0)$ and our predicted reverse process $p_\theta(x_{t-1} \mid x_t)$, we need to parametrize the predicted reverse transition. (Austin et al. 2021) parameterize it as

$$p_\theta(x_{t-1} \mid x_t) \;\propto\; \sum_{\tilde{x}_0} q(x_{t-1}, x_t \mid \tilde{x}_0)\,\tilde{p}_\theta(\tilde{x}_0 \mid x_t),$$

where we sum over all possible clean outputs $\tilde{x}_0$.
(Note: the prediction network is denoted $\tilde{p}_\theta(\tilde{x}_0 \mid x_t)$ rather than $p_\theta$ to emphasize that it is a neural network that predicts the distribution of the final denoised output $\tilde{x}_0$ given $x_t$. The reverse transition $p_\theta(x_{t-1} \mid x_t)$ defined above incorporates $\tilde{p}_\theta$ to predict the output one step away.)
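A sketch of how this parameterization can be computed in practice for a single one-hot token, using the Markov property $q(x_{t-1}, x_t \mid \tilde{x}_0) = q(x_t \mid x_{t-1})\, q(x_{t-1} \mid \tilde{x}_0)$ (`Q_t` is the single-step matrix and `Q_bar_tm1` the cumulative product up to $t-1$, as in the sketch above):

```python
import numpy as np

def p_theta_step(x_t_onehot, p0_pred, Q_t, Q_bar_tm1):
    """x0-parameterized reverse transition of (Austin et al. 2021):
    p_theta(x_{t-1} | x_t) is proportional to the sum over x0~ of
    q(x_{t-1}, x_t | x0~) * p~_theta(x0~ | x_t), which reduces to an
    elementwise product of two vectors over the candidate states x_{t-1}."""
    unnorm = (x_t_onehot @ Q_t.T) * (p0_pred @ Q_bar_tm1)
    return unnorm / unnorm.sum()
```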
The real kicker: masked language models as discrete diffusion models
Masked language models, like BERT, learn to generate text in a sequence based on a given context. Specifically, during training, input sequences are randomly corrupted (i.e. some tokens become [MASK]) and the language model tries to predict and recover the masked tokens given the unmasked context. (Austin et al. 2021) in fact prove that masked language models trained under a cross-entropy objective are discrete diffusion models; specifically, they are absorbing-state discrete diffusion models. In an absorbing-state discrete diffusion model, each token either stays in the same state or transitions to an absorbing state (e.g. [MASK]) where it gets stuck forever. If we say that the probability of transitioning to the absorbing state $m$ is $\beta_t$, then we can define the transition probabilities by the matrix entries

$$[Q_t]_{ij} = \begin{cases} 1 & \text{if } i = j = m \\ 1 - \beta_t & \text{if } i = j \neq m \\ \beta_t & \text{if } j = m,\ i \neq m \\ 0 & \text{otherwise.} \end{cases}$$
This transition matrix is such that for large $t$, $\bar{Q}_t = Q_1 Q_2 \cdots Q_t$ converges to the stationary distribution with all probability mass on the absorbing [MASK] state.
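A sketch of the absorbing-state transition matrix, extending the residue alphabet with a [MASK] index (`beta` and the number of composed steps are illustrative):

```python
import numpy as np

K = 20
MASK = K                    # index of the absorbing [MASK] state

def absorbing_Q(beta):
    """Non-mask tokens stay put with prob 1 - beta and jump to [MASK] with prob beta;
    [MASK] transitions to itself with probability 1 (it is absorbing)."""
    Q = (1 - beta) * np.eye(K + 1)
    Q[:, MASK] += beta      # every state gains probability beta of becoming [MASK]
    Q[MASK, MASK] = 1.0     # the absorbing state never leaves
    return Q

# Composing many steps pushes essentially all probability mass onto [MASK]:
Q_bar = np.linalg.multi_dot([absorbing_Q(0.05)] * 100)
```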
Implementing masked language models (MLMs) for inverse protein folding
To train the MLM diffusion model, (Anand et al. 2022) randomly mask a fraction of the protein sequence and have the model predict the masked parts of the sequence. The mask fraction is interpolated linearly over [0, 1], so the model learns to predict from more and more noise. Then, to generate a predicted sequence, it starts with a completely masked sequence (the equivalent of pure noise); here we are sampling our noisy sequence from the stationary distribution of the forward diffusion process, which has all probability mass on the [MASK] state. To progressively denoise (reverse diffuse) the sequence, at each round the model's denoised prediction is re-masked with a probability corresponding to the current timestep, and the masked residues are predicted again to generate the next, less noisy sequence, repeating until we arrive at a fully generated protein sequence.
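A rough sketch of that sampling loop (the linear re-masking schedule and the `mlm(seq, structure)` interface, which returns per-position residue probabilities, are assumptions; the exact schedule in (Anand et al. 2022) may differ):

```python
import numpy as np

def sample_sequence(mlm, structure, length, T, rng, K=20, MASK=20):
    """Absorbing-state sampling: start fully masked, then iteratively predict
    residues and re-mask with a probability that shrinks as t decreases."""
    seq = np.full(length, MASK)
    for t in reversed(range(1, T + 1)):
        probs = mlm(seq, structure)                       # (length, K) residue distributions
        denoised = np.array([rng.choice(K, p=p) for p in probs])
        remask = rng.random(length) < (t - 1) / T         # no re-masking at the final step
        seq = np.where(remask, MASK, denoised)
    return seq
```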
Order-agnostic autoregressive models (OARM) as discrete diffusion models
It turns out that the other kind of model typical in language modeling, the autoregressive model, can also be framed as a discrete diffusion model. (Hoogeboom et al. 2022) introduce order-agnostic autoregressive diffusion models, which are derived from autoregressive models and can learn to generate in any order.
The typical autoregressive log likelihood is precisely $\log p(x) = \sum_{t=1}^{D} \log p(x_t \mid x_{<t})$, where $D$ is the length of the sequence. In the autoregressive setting, earlier parts of the sequence contain information that is used to predict the later elements. However, in cases where there is no pre-defined or natural ordering of the sequence, we can instead use an order-agnostic autoregressive objective, where we first randomly permute the sequence before prediction. Our log likelihood then has the following lower bound

$$\log p(x) \;\geq\; \mathbb{E}_{\sigma \sim \mathcal{U}(S_D)} \sum_{t=1}^{D} \log p\big(x_{\sigma(t)} \mid x_{\sigma(<t)}\big),$$

where the inequality is due to Jensen's inequality with a concave function. We can further re-express our lower bound as an expectation over a uniformly sampled timestep,

$$\log p(x) \;\geq\; D \cdot \mathbb{E}_{t \sim \mathcal{U}(1, \dots, D)}\, \mathbb{E}_{\sigma}\, \log p\big(x_{\sigma(t)} \mid x_{\sigma(<t)}\big).$$

This can be expressed as (full details in (Hoogeboom et al. 2022))

$$\log p(x) \;\geq\; D \cdot \mathbb{E}_{t \sim \mathcal{U}(1, \dots, D)}\, \mathbb{E}_{\sigma}\, \frac{1}{D - t + 1} \sum_{k \in \sigma(\geq t)} \log p\big(x_k \mid x_{\sigma(<t)}\big).$$

If we notate $\mathcal{L}_t = \frac{1}{D - t + 1}\, \mathbb{E}_{\sigma} \sum_{k \in \sigma(\geq t)} \log p\big(x_k \mid x_{\sigma(<t)}\big)$, our likelihood lower bound objective is $\mathbb{E}_{t \sim \mathcal{U}(1, \dots, D)}\big[D \cdot \mathcal{L}_t\big]$.
We see from the term $\mathcal{L}_t$ that the first $t-1$ elements of a randomly permuted sequence are used to predict the rest. The way we optimize this in practice is with a BERT-style objective: we mask the tail elements of the permuted sequence and predict them using the unmasked head elements.
Training our model in practice
To optimize our likelihood bound, we need to predict the probability distribution of the masked elements. To do this we train a single model that predicts a probability distribution over each element in the sequence. Notationally, let $x$ be our sequence of categorical variables and let our single neural network $f$ be a model that outputs a probability distribution over the categories for each element in the sequence. To predict the probability distributions of all of the masked tokens, given a mask that corrupts certain tokens to [MASK] and leaves the rest unchanged, we pass the masked sequence into our network $f$ and extract from the output the predicted probability vectors at the masked positions.
Sampling our OARM model to generate sequences
Similar to the absorbing-state diffusion case, to generate a predicted output sequence we first start with an all-masked sequence and then use our model to predict the distribution of elements from which we sample the next sequence. At each timestep, as $t$ goes from $1$ to $D$, we mask the sequence less and less; whereas in (Anand et al. 2022) the probability of masking decreases, here we mask the tail of the permuted sequence.
To write the sampling process explicitly (a code sketch follows these steps):
Sample a random permutation $\sigma$ and start with an all-masked sequence $x^{(0)}$.
For $t$ in $1, \dots, D$:
- Create the mask $m_t$, which masks everything except the first $t-1$ tokens in the permuted order.
- Sample $\hat{x}$ from the predicted probability distributions output by our trained network $f$ given the masked sequence as input.
- Combine the non-masked parts from $x^{(t-1)}$ and the predicted masked parts from $\hat{x}$ to generate the next timestep's sequence $x^{(t)}$.
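A numpy sketch of this sampling loop, committing one position per step (the network interface `f(seq)`, returning per-position category probabilities, is an assumption, and `K`/`MASK` follow the earlier sketches):

```python
import numpy as np

def oa_arm_sample(f, length, rng, K=20, MASK=20):
    """Order-agnostic autoregressive sampling: fill positions in a random order,
    conditioning each prediction on everything generated so far."""
    sigma = rng.permutation(length)        # random generation order
    seq = np.full(length, MASK)
    for t in range(length):
        probs = f(seq)                     # (length, K) predicted category distributions
        pos = sigma[t]                     # next position in the permuted order
        seq[pos] = rng.choice(K, p=probs[pos])
    return seq
```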
Proposed method: diffusion sequence inpainting
The inverse folding problem is to predict a protein sequence given a protein structure; thus, a natural approach is to directly condition the diffusion model on the protein structure to generate a protein sequence. For example, (Anand et al. 2022) do this with an absorbing-state discrete diffusion model and (Yi et al. 2023) do this with a graph diffusion model.
However, instead of directly conditioning the diffusion model on the protein structure, we can follow the RePaint image inpainting framework of (Lugmayr et al. 2022) and condition on noised protein structure data. To provide some intuition: in the image inpainting problem, part of the image is missing or unknown, and we want to use the known parts of the image to inpaint (fill in) the missing region. In our setting, we can imagine that a protein sequence and structure pair makes up an "image" and the protein sequence is the missing information that we want to inpaint. One way to do this, as mentioned before, is to directly condition on the known data to predict the missing data.
Alternatively, we can imagine noising the entire image and learning to reverse diffuse the entire image while conditioning on noisy known data, so that the final generated image contains the correct known region and an inpainted unknown region that is consistent with it. (We can view diffusion on the entire image as a joint diffusion on the known and unknown regions, where the noises are identical in distribution.) In our setting, we would use noised protein structure data to condition our diffusion model to generate a corresponding protein sequence.
RePaint framework
RePaint differs from ordinary diffusion on the entire image in that it conditions on a preset noisy version of the known region and it also employs a resampling step. Notationally, given a ground-truth image $x_0$, a mask $m$ marking the known pixels, unknown pixels $(1 - m) \odot x$, and known pixels $m \odot x$, they define the one-step reverse diffused output to be a combination of the unknown pixel region of the reverse diffused image (as in typical diffusion) and the known pixel region of a preset corrupted ground truth:

$$x_{t-1} = m \odot x_{t-1}^{\text{known}} + (1 - m) \odot x_{t-1}^{\text{unknown}}, \qquad x_{t-1}^{\text{known}} \sim \mathcal{N}\big(\sqrt{\bar{\alpha}_{t-1}}\, x_0,\, (1 - \bar{\alpha}_{t-1})\mathbf{I}\big), \qquad x_{t-1}^{\text{unknown}} \sim p_\theta(x_{t-1} \mid x_t).$$
What (Lugmayr et al. 2022) observe is that the predicted inpainted region often does not harmonize well with the rest of the image, even though it integrates the conditioning known data. They therefore use a resampling step with some jump size $j$: the reverse diffusion output $x_{t-1}$ is re-noised back to $x_{t-1+j}$ by undergoing a series of forward noise steps $q(x_t \mid x_{t-1})$, and the reverse step is repeated.
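A minimal sketch of one RePaint-style reverse step with resampling, in the continuous Gaussian setting of the earlier sketches (`model`, `betas`, and `alpha_bars` as before; jump size 1 and schematic timestep indexing):

```python
import torch

@torch.no_grad()
def repaint_step(model, x_t, x0_known, mask, t, n_resample=5):
    """Combine a preset-noised known region with a learned reverse step on the
    unknown region, then re-noise and repeat (RePaint-style resampling)."""
    for r in range(n_resample):
        # known region: forward-noise the ground truth directly to level t-1
        noise = torch.randn_like(x0_known)
        x_known = alpha_bars[t - 1].sqrt() * x0_known + (1 - alpha_bars[t - 1]).sqrt() * noise
        # unknown region: one ordinary learned reverse step from x_t
        mu, sigma = model(x_t, t)
        x_unknown = mu + sigma * torch.randn_like(mu)
        x_prev = mask * x_known + (1 - mask) * x_unknown
        if r < n_resample - 1:
            # resample: push x_{t-1} forward one noise step and redo the reverse step
            x_t = (1 - betas[t]).sqrt() * x_prev + betas[t].sqrt() * torch.randn_like(x_prev)
    return x_prev
```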
Inpainting for protein sequence recovery
We can adapt the inpainting method to our setting by considering the joint distribution of the protein structure and the protein sequence, where we wish to inpaint the missing protein sequence data.
For simplicity, we first consider the continuous diffusion setting. We follow (McPartlon 2023) and use an autoencoder to extract a continuous latent on which we can run diffusion. Specifically, we can train an autoencoder to encode the graph structure into some latent vector, and another autoencoder to encode the protein sequence. (Here we can consider using similar autoencoders so that the two latent representations lie in the same data space.) We then follow (Lugmayr et al. 2022) and train a diffusion model on the latent data using Gaussian noise and the typical diffusion objective function. Finally, we recover an inpainted latent sequence, which we decode with the sequence autoencoder to generate our predicted protein sequence.
However, we might consider it unnatural to encode a discrete sequence into a continuous latent. Thus we can use a discrete diffusion model for our protein sequence generation. For example, as in (Yi et al. 2023), we can use a discrete diffusion model based on a doubly stochastic transition matrix derived from the empirical sequence mutation data in BLOSUM; this grounds the transition matrix in empirical data. For the protein structure data, which is naturally continuous, we can either use a hand-crafted geometric representation as in (Wu et al. 2022), or follow (Hsu et al. 2022) and use a latent representation from the GVP-GNN graph neural network of (Jing et al. 2021) to get a geometrically invariant representation of the protein structure.
For the diffusion process we follow the RePaint method described above. The continuous protein structure data follows the typical Gaussian setting:

$$q(\mathcal{X}_t \mid \mathcal{X}_{t-1}) = \mathcal{N}\big(\mathcal{X}_t;\, \sqrt{1 - \beta_t}\, \mathcal{X}_{t-1},\, \beta_t \mathbf{I}\big).$$
To get our reverse diffused sequence data, we learn the reverse process $p_\theta(s_{t-1} \mid s_t, \mathcal{X}_t)$, which following (Austin et al. 2021) we parametrize as

$$p_\theta(s_{t-1} \mid s_t, \mathcal{X}_t) \;\propto\; \sum_{\tilde{s}_0} q(s_{t-1}, s_t \mid \tilde{s}_0)\,\tilde{p}_\theta(\tilde{s}_0 \mid s_t, \mathcal{X}_t).$$
Thus we train a network $\tilde{p}_\theta(\tilde{s}_0 \mid s_t, \mathcal{X}_t)$ to predict the distribution of the final denoised sequence conditioned on the previous noisy output, which we denote $x_t = [\mathcal{X}_t, s_t]$, the concatenation of the noised protein structure data $\mathcal{X}_t$ and the partially denoised sequence data $s_t$. (Note: this keeps the spirit of the notation in (Lugmayr et al. 2022), where the noisy protein structure plays the role of the known region and the noisy protein sequence plays the role of the unknown region.) The training objective is the likelihood lower bound objective for discrete diffusion models stated previously.
To generate our predicted inpainted protein sequence, we sample a pure-noise sequence from a uniform distribution and then run the reverse diffusion process with our trained diffusion model. (Here we sample from the uniform distribution because a doubly stochastic transition matrix converges to a uniform stationary distribution.)
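A speculative end-to-end sketch of one reverse step of the proposed method: the structure acts as the RePaint "known" region and is re-noised from the true structure at each level, while the sequence takes a discrete reverse step conditioned on that noisy structure. `seq_net` is a hypothetical network returning per-position distributions over clean residues, and `p_theta_step` is the helper from the earlier discrete-diffusion sketch; the Gaussian schedule here is an assumption.

```python
import numpy as np

betas_np = np.linspace(1e-4, 0.02, 1000)
a_bar = np.cumprod(1.0 - betas_np)               # Gaussian schedule for the structure data

def inpaint_reverse_step(seq_net, s_t, X_true, t, Q_t, Q_bar_tm1, rng, K=20):
    """One reverse step: re-noise the true structure to level t, then denoise
    the (integer-coded) discrete sequence one step, conditioning on that noisy structure."""
    X_t = np.sqrt(a_bar[t]) * X_true + np.sqrt(1 - a_bar[t]) * rng.standard_normal(X_true.shape)
    p0 = seq_net(s_t, X_t)                        # (length, K) predicted clean-residue distributions
    s_prev = np.empty_like(s_t)
    for i in range(len(s_t)):
        post = p_theta_step(np.eye(K)[s_t[i]], p0[i], Q_t, Q_bar_tm1)
        s_prev[i] = rng.choice(K, p=post)
    return s_prev
```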
Choice of discrete diffusion model
So far we have seen an absorbing-state discrete diffusion model, a more general discrete diffusion model defined by some transition matrix $Q_t$, and an order-agnostic autoregressive discrete diffusion model. (Anand et al. 2022) use an absorbing-state discrete diffusion model, while (Yi et al. 2023) use a discrete diffusion model whose transition matrix comes from BLOSUM. (Alamdari et al. 2023) compare an order-agnostic autoregressive discrete diffusion model, a BLOSUM-based discrete diffusion model, and an absorbing-state discrete diffusion model, and find that the order-agnostic autoregressive model performs best for sequence generation in protein design. Following this, we might instead decide to use an order-agnostic autoregressive model.
Other ideas: joint distribution and classifier-free guidance
The inpainting method of (Lugmayr et al. 2022) can be seen as an intermediate between a direct-conditioning diffusion model, where we learn the noise on the sequence conditioned on the true protein structure, and a joint diffusion model, where we learn both the noise on the sequence and the noise on the structure.
We could consider the joint diffusion setting where we do not explicitly condition on the structure information and instead jointly reverse diffuse a pure-noise sequence and structure sample to generate a realistic sequence and structure pair. However, we would need some way of enforcing that the generated protein structure matches the true input protein structure in our inverse folding setting.
(Ho et al. 2022), in their work on classifier-free guidance, show better results when combining both a conditional and an unconditional model. Following this, we might consider learning both a conditional and an unconditional discrete diffusion model to generate sequences. For example, the unconditional model, parameterized by weights $\theta$, can learn the reverse process by training a network to predict the distribution of the final denoised sequence based only on the previous denoised sequence, i.e. $\tilde{p}_\theta(\tilde{s}_0 \mid s_t)$. For the conditional model, parameterized by weights $\phi$, we can encode our graph structure information into a geometrically invariant representation and learn the reverse diffusion process by training a network to predict the distribution of the final denoised sequence given the previous denoised sequence and the true protein structure, i.e. $\tilde{p}_\phi(\tilde{s}_0 \mid s_t, \mathcal{X})$.
To generate sequences, we can linearly combine the predicted distributions of the two models with some weighting parameter $w$ to get the distribution used in the reverse diffusion process, and generate our predicted sequences from it.
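One possible (assumed, not taken from the cited papers) way to combine the two predicted categorical distributions while keeping the result a valid distribution, since a raw linear extrapolation with weight $w$ can produce negative probabilities:

```python
import numpy as np

def guided_distribution(p_cond, p_uncond, w):
    """Guidance-style combination of conditional and unconditional predictions,
    clipped and renormalised so each row is a valid categorical distribution."""
    p = (1 + w) * p_cond - w * p_uncond
    p = np.clip(p, 1e-12, None)
    return p / p.sum(axis=-1, keepdims=True)
```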
Data augmentation ideas
The current lack of paired sequence and structure data bottlenecks inverse protein folding models. Following (Hsu et al. 2022), we can use AlphaFold2 to predict structures for sequences from a dataset such as UniRef50 to augment our sequence and structure dataset. We may also consider adding perturbed AlphaFold structures, as (Dauparas et al. 2022) show that adding small Gaussian noise to the AlphaFold protein structures improves performance. We can also consider noisy-student self-distillation: following (Jumper et al. 2021), we take protein structures lacking sequence information, predict sequences with our inverse folding model, add the structure and predicted sequence pairs to our dataset, and retrain our model.
References
(Luo 2022) Understanding Diffusion Models: A Unified Perspective
(Weng 2021) What are diffusion models?
(Sohl-Dickstein 2015) Deep Unsupervised Learning using Nonequilibrium Thermodynamics
(Ho 2020) Denoising Diffusion Probabilistic Models
(Wu 2022) Protein structure generation via folding diffusion
(Hsu 2022) Learning inverse folding from millions of predicted structures
(Ho 2022) Classifier-Free Diffusion Guidance
(Bao 2023) One Transformer Fits All Distributions in Multi-Modal Diffusion at Scale
(Rombach 2022) High-Resolution Image Synthesis with Latent Diffusion Models
(Austin 2021) Structured Denoising Diffusion Models in Discrete State-Spaces
(Anand 2022) Protein Structure and Sequence Generation with Equivariant Denoising Diffusion Probabilistic Models
(Hoogeboom 2022) Autoregressive Diffusion Models
(Yi 2023) Graph Denoising Diffusion for Inverse Protein Folding
(Lugmayr 2022) RePaint: Inpainting using Denoising Diffusion Probabilistic Models
(Jing 2021) Learning from Protein Structure with Geometric Vector Perceptrons
(Alamdari 2023) Protein generation with evolutionary diffusion: sequence is all you need
(Dauparas 2022) Robust deep learning-based protein sequence design using ProteinMPNN
(Jumper 2021) Highly accurate protein structure prediction with AlphaFold