Distillation of MSA Embeddings to Folded Protein Structures using Graph Transformers
20220392566 · 2022-12-08
Inventors
- Pranam Chatterjee (Cambridge, MA, US)
- Allan S Costa (Boston, MA, US)
- Joseph M. Jacobson (Newton, MA)
- Raghava Manvith Ponnapati (Cambridge, MA, US)
Abstract
An attention-based graph architecture that exploits MSA Transformer embeddings to directly produce models of three-dimensional folded structures from protein sequences includes a method and system for augmenting the protein sequence to obtain multiple sequence alignments, producing enriched individual and pairwise embeddings from the multiple sequence alignments using an MSA-Transformer, extracting relevant features and structure latent states from the enriched individual and pairwise embeddings for use by a downstream graph transformer, assigning individual and pairwise embeddings to nodes and edges, respectively, using the downstream graph transformer to operate on node representations through an attention-based mechanism that considers pairwise edge attributes to obtain final node encodings, and projecting the final node encodings to form the computer-modeled folded protein structure. An induced distogram of the computer-modeled folded protein structure may be computed.
Claims
1. A method for computer modelling of a three-dimensional folded protein structure based on a protein sequence, comprising: using a computer processor, performing the steps of: augmenting the protein sequence to obtain multiple sequence alignments; using an MSA-Transformer, producing enriched individual and pairwise embeddings from the multiple sequence alignments; extracting, from the enriched individual and pairwise embeddings, relevant features and structure latent states for use by a downstream graph transformer; assigning individual and pairwise embeddings to nodes and edges, respectively; using the downstream graph transformer, operating on node representations through an attention-based mechanism that considers pairwise edge attributes to obtain final node encodings; and projecting the final node encodings to form the computer-modeled folded protein structure.
2. The method of claim 1, further comprising computing an induced distogram of the computer-modeled folded protein structure.
3. The method of claim 1, further comprising storing any individual and pairwise embeddings that are from the original protein sequence.
4. A method for folding a protein sequence in silico using an attention-based graph transformer architecture, comprising: using the MSA transformer, producing information-dense embeddings from the protein sequence; from the embeddings, producing initial node and edge hidden representations in a complete graph; using the attention-based graph transformer architecture, processing and structuring geometric information, to obtain final node representations; and projecting the final node representations into Cartesian coordinates through a learnable transformation to obtain the folded protein sequence.
5. The method of claim 4, further comprising calculating induced distance maps from the projected final node representations.
6. The method of claim 5, further comprising comparing the induced distance maps to ground truth counterparts in order to define the loss.
7. A system for producing models of three-dimensional folded protein structures from protein sequences, comprising a computer processor or set of processors specially adapted for performing the steps of: augmenting a protein sequence to obtain multiple sequence alignments; using an MSA-Transformer, producing enriched individual and pairwise embeddings from the multiple sequence alignments; extracting, from the enriched individual and pairwise embeddings, relevant features and structure latent states for use by a downstream graph transformer; assigning individual and pairwise embeddings to nodes and edges, respectively; using the downstream graph transformer, operating on node representations through an attention-based mechanism that considers pairwise edge attributes to obtain final node encodings; and projecting the final node encodings to form a model three-dimensional folded protein structure.
8. The system of claim 7, wherein the computer processor or set of processors is further specially adapted for performing the step of computing an induced distogram of the computer-modeled folded protein structure.
Description
BRIEF DESCRIPTION OF THE DRAWINGS
[0012] Other aspects, advantages and novel features of the invention will become more apparent from the following detailed description of the invention when considered in conjunction with the accompanying drawings.
DETAILED DESCRIPTION
[0016] In the present invention, the protein folding problem is treated as a graph optimization problem. Information-dense embeddings produced by the MSA Transformer [Rao, R., Liu, J., Verkuil, R., Meier, J., Canny, J. F., Abbeel, P., Sercu, T., and Rives, A., “MSA transformer”, International Conference on Machine Learning, pp. 8844-8856. PMLR, 2021] are harvested and then used to produce initial node and edge hidden representations in a complete graph. To process and structure geometric information, the attention-based architecture of the Graph Transformer is employed, as proposed by Shi et al. [Shi, Y., Huang, Z., Feng, S., Zhong, H., Wang, W., and Sun, Y., “Masked label prediction: Unified message passing model for semi-supervised classification”, arXiv preprint arXiv:2009.03509, 2021]. Final node representations are then projected into Cartesian coordinates through a learnable transformation, and the resulting induced distance maps are compared to their ground truth counterparts in order to define the loss for training.
[0017] MSA Transformer Data Augmentation
[0018] The MSA Transformer is an unsupervised protein language model that produces information-rich residue embeddings [Rao, R., Liu, J., Verkuil, R., Meier, J., Canny, J. F., Abbeel, P., Sercu, T., and Rives, A., “MSA transformer”, International Conference on Machine Learning, pp. 8844-8856. PMLR, 2021]. In contrast to other protein language models, it operates on two-dimensional inputs consisting of a length-N query sequence along with its MSA sequences. It utilizes an Axial Transformer [Ho, J., Kalchbrenner, N., Weissenborn, D., and Salimans, T., “Axial attention in multidimensional transformers”, arXiv preprint arXiv:1912.12180, 2019] as an efficient attention-based architecture for performing computation on its layers' O(N·S) representations, where S is the total number of input MSA sequences.
[0019] In a preferred embodiment, the present invention operates on graph features distilled from MSA Transformer encodings. Last-layer residue embeddings capture individual and contextual residue properties. Similarly, the vector formed by pairwise attention scores at each layer and head captures attentive interactions between residue pairs. The richness of information present at these vectors has been previously demonstrated in state-of-the-art contact prediction [Rao, R., Liu, J., Verkuil, R., Meier, J., Canny, J. F., Abbeel, P., Sercu, T., and Rives, A., “MSA transformer”, International Conference on Machine Learning, pp. 8844-8856. PMLR, 2021]. The present invention extends those individual and pairwise embeddings to node and edge representations, demonstrating that learning over the resulting graph can resolve a protein's three-dimensional structure.
[0020] One particular implementation of the invention employs the 100 million-parameter ESM-MSA-1 model [Rao, R., Liu, J., Verkuil, R., Meier, J., Canny, J. F., Abbeel, P., Sercu, T., and Rives, A., “MSA transformer”, International Conference on Machine Learning, pp. 8844-8856. PMLR, 2021], which was trained on 26 million MSAs queried from UniRef50 and sourced from UniClust30. ESM-MSA-1 produces N residue embeddings, h_i* ∈ ℝ^768, and N×N attention score traces, h_ij* ∈ ℝ^144, for each input sequence. Since the MSA Transformer is computationally expensive to evaluate for large S, even in the context of inference, the encodings were precomputed and made readily available for training. This implementation uses S = 64, stored residue embeddings {h_i*}, and attention score traces {h_ij*}_(j>i) for each query sequence.
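The stored representations above can be illustrated with a minimal numpy sketch. The shapes follow the text (ℝ^768 residue embeddings, ℝ^144 pairwise attention traces, the latter corresponding to stacking scores across ESM-MSA-1's 12 layers × 12 heads); the random tensors stand in for real model outputs, and the symmetrization of the upper-triangular traces {h_ij*}_(j>i) is an illustrative assumption:

```python
import numpy as np

N = 7  # query sequence length (toy value)
rng = np.random.default_rng(0)

# Stand-ins for precomputed ESM-MSA-1 outputs:
residue_emb = rng.standard_normal((N, 768))        # h_i* in R^768
upper = {(i, j): rng.standard_normal(144)          # stored traces, j > i only
         for i in range(N) for j in range(i + 1, N)}

# Recover the full N x N x 144 pairwise tensor by symmetrization
# (diagonal left at zero for illustration).
pairwise = np.zeros((N, N, 144))
for (i, j), trace in upper.items():
    pairwise[i, j] = trace
    pairwise[j, i] = trace
```

Storing only the j > i traces halves the on-disk footprint of the precomputed encodings.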
[0021] For training and validation, the ESM Structural Split [Rives, A., Meier, J., Sercu, T., Goyal, S., Lin, Z., Liu, J., Guo, D., Ott, M., Zitnick, C. L., Ma, J., and Fergus, R., “Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences”, Proceedings of the National Academy of Sciences, 118(15):e2016239118, 2021] was used, which builds upon trRosetta's training dataset [Yang, J., Anishchenko, I., Park, H., Peng, Z., Ovchinnikov, S., and Baker, D., “Improved protein structure prediction using predicted interresidue orientations”, Proceedings of the National Academy of Sciences, 117(3):1496-1503, 2020]. To overcome the bottleneck associated with reading large encodings directly from the file system, the splits were fixed to the first superfamily split, as specified in Rives et al., and its MSA Transformer encodings were serialized into tar shards. A virtual layer of data shuffling was added through the WebDataset framework [Aizman, A., Maltby, G., and Breuel, T., “High performance i/o for large scale deep learning”, IEEE International Conference on Big Data (Big Data), 5965-5967, 2019]. The resulting dataset of graph features occupies 0.25 TB.
[0022] The final node encodings are projected into Cartesian coordinates in ℝ³ 180, and the induced distogram 185 is computed for the loss.
[0023] Graph Building
[0024] In a preferred embodiment, a protein is treated as an attributed complete graph. H_V and H_E are the dimensionalities of the node and edge representations, respectively. These attributes are extracted from MSA-Transformer embeddings through standard deep neural networks:

h_i = σ(W_E h_i*)

h_ij = σ(W_D h_ij*)

where h_i ∈ ℝ^(H_V) and h_ij ∈ ℝ^(H_E).
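The distillation step above can be sketched in numpy. The text specifies "standard deep neural networks"; for brevity this sketch uses single affine layers as stand-ins for W_E and W_D, with ReLU assumed for the nonlinearity σ — both are illustrative assumptions, not the claimed implementation:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)  # stand-in for the nonlinearity sigma

N, H_V, H_E = 7, 64, 64  # toy sizes; H_V = H_E = 64 matches the best model
rng = np.random.default_rng(1)

h_star = rng.standard_normal((N, 768))          # MSA Transformer residue embeddings h_i*
h_pair_star = rng.standard_normal((N, N, 144))  # attention score traces h_ij*

# Single-layer stand-ins for the deep networks W_E, W_D:
W_E = 0.05 * rng.standard_normal((768, H_V))
W_D = 0.05 * rng.standard_normal((144, H_E))

h_node = relu(h_star @ W_E)        # h_i  in R^{H_V}: node attributes
h_edge = relu(h_pair_star @ W_D)   # h_ij in R^{H_E}: edge attributes
```

Because the graph is complete, every residue pair (i, j) carries an edge attribute h_ij.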
[0025] Graph Transformer
[0026] The Graph Transformer used in the preferred embodiment was introduced in Shi et al. [Shi, Y., Huang, Z., Feng, S., Zhong, H., Wang, W., and Sun, Y., “Masked label prediction: Unified message passing model for semi-supervised classification”, arXiv preprint arXiv:2009.03509, 2021] in order to incorporate edge features directly into graph attention. This is made possible by summing transformations of the edge attributes directly into the keys and values of the attention mechanism. The present invention approaches protein folding with a variation of this architecture. Considering layer-l node hidden states, {h_i^(l)}, and similarly learned edge latent states, {e_ij}, if C attention heads are employed, a layer update can be written as

h_i^(l+1) = W_R^(l) h_i^(l) + W_A^(l) (⊕_(c=1..C) Σ_(j≠i) α_ij^(l,c) (v_j^(l,c) + e_ij^(c)))

where ⊕ denotes concatenation, and W_A^(l) and W_R^(l) are learnable projections. As in the original architecture, batch normalization is applied at each layer. The attention scores α_ij^(l,c), node values v_j^(l,c), and edge values e_ij^(c) are obtained from learnable transformations of the original node hidden states and edge attributes:
q_i^(l,c) = W_q^(l,c) h_i^(l)  k_i^(l,c) = W_k^(l,c) h_i^(l)

v_i^(l,c) = W_v^(l,c) h_i^(l)  e_ij^(c) = W_e^(c) h_ij
The attention scores are normalized over each node's neighbors, as in graph attention:

α_ij^(l,c) = softmax_j(⟨q_i^(l,c), k_j^(l,c) + e_ij^(c)⟩ / √(H_V/C))

To hold computational costs roughly constant as the number of heads varies, {q_i^(c), v_i^(c), k_i^(c), e_ij^(c)} ∈ ℝ^(H_V/C).
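A single edge-aware attention layer of this kind can be sketched in numpy. This is a minimal sketch under stated assumptions (softmax-normalized scores ⟨q_i, k_j + e_ij⟩/√d, per-head width d = H/C, and the residual/aggregation projections named W_R and W_A here); it omits batch normalization and is not the claimed production architecture:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerically stable softmax
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

N, H, C = 5, 8, 2   # nodes, hidden size, attention heads (toy values)
d = H // C          # per-head width, keeping cost roughly constant in C
rng = np.random.default_rng(2)

h = rng.standard_normal((N, H))          # node hidden states h_i^(l)
e_attr = rng.standard_normal((N, N, H))  # edge attributes h_ij

Wq, Wk, Wv = (0.1 * rng.standard_normal((C, H, d)) for _ in range(3))
We = 0.1 * rng.standard_normal((C, H, d))
W_A = 0.1 * rng.standard_normal((H, H))  # projection of concatenated heads
W_R = 0.1 * rng.standard_normal((H, H))  # residual-style projection of h_i

heads = []
for c in range(C):
    q, k, v = h @ Wq[c], h @ Wk[c], h @ Wv[c]   # (N, d) each
    e = e_attr @ We[c]                           # (N, N, d) edge values e_ij^(c)
    # scores <q_i, k_j + e_ij> / sqrt(d), normalized over neighbors j
    scores = np.einsum('id,ijd->ij', q, k[None, :, :] + e) / np.sqrt(d)
    alpha = softmax(scores, axis=1)
    # aggregate edge-augmented values, weighted by attention
    heads.append(np.einsum('ij,ijd->id', alpha, v[None, :, :] + e))

# layer update: concatenate heads, project, and add the projected input
h_next = np.concatenate(heads, axis=1) @ W_A + h @ W_R
```

Because the protein graph is complete, the attention of node i spans all other residues, so the softmax runs over the full row of pairwise scores.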
[0027] Cartesian Projection and Loss
[0028] In a preferred embodiment, a predictor is trained to recover coordinates of each residue in a learned canonical pose:
X_i = W_X h_i^(L)
where X_i ∈ ℝ³. To train the network, a distogram-based loss function is used on the resulting distance map. D̂_ij = ‖X_i − X_j‖₂ is the induced Euclidean distance between the Cartesian projections of nodes i and j, and D_ij is the ground truth distance. The loss is based on the L₁-norm of the difference between those values:

[0029] ℒ = (1/N²) Σ_(i,j) |D̂_ij − D_ij|
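The projection and loss can be sketched in a few lines of numpy. The random tensors stand in for trained weights and real coordinates, and the mean-over-pairs form of the L₁ distogram loss is an assumption for illustration:

```python
import numpy as np

N, H = 6, 16  # residues, final hidden size (toy values)
rng = np.random.default_rng(3)

h_final = rng.standard_normal((N, H))    # final node encodings h_i^(L)
W_X = 0.1 * rng.standard_normal((H, 3))  # learnable projection to R^3

X = h_final @ W_X  # X_i: coordinates in a learned canonical pose
# induced pairwise distance map D_hat_ij = ||X_i - X_j||_2
D_hat = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)

# stand-in ground-truth distance map (symmetric, zero diagonal)
D_true = np.abs(rng.standard_normal((N, N)))
D_true = (D_true + D_true.T) / 2
np.fill_diagonal(D_true, 0.0)

loss = np.abs(D_hat - D_true).mean()  # L1 distogram loss
```

Because the loss compares distances rather than raw coordinates, it is invariant to rigid rotations and translations of the predicted pose.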
[0030] Model Training
[0031] To optimize the trained model, a shallow random hyperparameter search over H_V ∈ {32, 64, 128, 256}, H_E ∈ {32, 64, 128}, L ∈ {3, 6, 10, 15}, and C ∈ {1, 2, 4} was performed. The Adam optimizer was utilized, with lr ∈ {1×10⁻³, 3×10⁻⁴, 1×10⁻⁴, 3×10⁻⁵, 1×10⁻⁵}. Variations of the loss function were also tested, including the MSE loss and weighted versions of the L₁ and MSE losses, for batch sizes B ∈ {10, 15, 30}.
[0032] To handle GPU memory constraints, gradient checkpointing was employed at each Graph Transformer layer. Models were trained in parallel on NVIDIA V100s provided by the MIT SuperCloud HPC [Reuther, A., Kepner, J., Byun, C., Samsi, S., Arcand, W., Bestor, D., Bergeron, B., Gadepally, V., Houle, M., Hubbell, M., et al., “Interactive supercomputing on 40,000 cores for machine learning and data analysis”, 2018 IEEE High Performance extreme Computing Conference (HPEC), pages 1-6, 2018].
[0033] In total, 40 search training runs were performed, with a maximum of 70 epochs and early stopping with a patience of 3 on the validation loss. The best model trained for 17 hours without triggering early stopping. With H_V = H_E = 64, L = 10, and C = 1, this model possesses a total of only 382K parameters. Using lr = 3×10⁻⁴ and B = 30, as well as an L₁ loss, ℒ_val = 2.25 and GDT_TS_val = 40.58 were achieved.
[0034] CASP13 Evaluation
[0035] To investigate the generalization of the model of the invention, it was evaluated on the free modeling targets from the 13th edition of the Critical Assessment of Protein Structure Prediction (CASP13). The model was benchmarked against the performance of the current state-of-the-art public architecture: trRosetta [Yang, J., Anishchenko, I., Park, H., Peng, Z., Ovchinnikov, S., and Baker, D., “Improved protein structure prediction using predicted interresidue orientations”, Proceedings of the National Academy of Sciences, 117(3):1496-1503, 2020]. trRosetta considers a sequence's MSA to predict distance probability volumes as well as relevant interresidue orientations. In contrast to the present invention, trRosetta relies on restraints derived from the predicted distance and orientations for downstream Rosetta minimization protocols [Rohl, C. A., Strauss, C. E., Misura, K. M., and Baker, D., “Protein structure prediction using Rosetta”, Methods in Enzymology, pages 66-93, Elsevier, 2004]. For each distance, trRosetta's best prediction is considered to be its expected value or its maximum likelihood estimate. dRMSD (distogram RMSD) between predicted distances and ground truth was utilized as the evaluation metric. To make a direct comparison, only distances that lie within trRosetta's binning range (2-20 Å) were considered.
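The evaluation metric described above can be sketched as a short function. The restriction of the comparison to ground-truth distances inside trRosetta's 2-20 Å binning range follows the text; the mean-over-masked-pairs form of dRMSD is an assumption for illustration:

```python
import numpy as np

def drmsd(D_pred, D_true, lo=2.0, hi=20.0):
    """dRMSD between distance maps, restricted to off-diagonal pairs whose
    ground-truth distance lies within the [lo, hi] Angstrom binning range."""
    n = D_true.shape[0]
    mask = (D_true >= lo) & (D_true <= hi) & ~np.eye(n, dtype=bool)
    diff = D_pred[mask] - D_true[mask]
    return float(np.sqrt((diff ** 2).mean()))

# Toy usage: distances induced by random coordinates, plus noise.
rng = np.random.default_rng(4)
X = 6.0 * rng.standard_normal((8, 3))
D = np.linalg.norm(X[:, None] - X[None, :], axis=-1)
D_noisy = D + 0.5 * rng.standard_normal(D.shape)
error = drmsd(D_noisy, D)
```

Masking on the ground-truth distances, rather than the predictions, keeps the comparison between the two methods over an identical set of residue pairs.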
[0037] Table 1 presents a comparison of CASP13 Free Modeling benchmarks of dRMSD for the architecture of the present invention's induced distances and trRosetta's expectation and argmax distances, against ground truth, considering only distances that lie within trRosetta's binning range.
TABLE 1

Target       Graph Transformer   trRosetta (argmax)   trRosetta (expectation)
T0987        3.722               2.135                1.638
T0969        3.080               1.583                1.288
T0955d1      5.346               2.400                2.160
T0998        3.476               1.482                1.247
T0990        3.017               1.356                1.078
T0958d1      2.886               1.947                1.796
T0968s2d1    3.380               1.927                1.695
T0963d2      7.853               4.039                2.982
T0953s2d3    5.404               4.647                3.681
T1010        4.002               2.048                1.531
T0968s1d1    3.905               2.226                1.797
T0957s2d1    2.559               1.700                1.492
T0950        3.392               1.542                1.148
T0953s1      2.698               1.897                1.618
T0953s2      4.158               3.868                3.144
T1022s1      2.604               1.665                1.433
[0038] These results demonstrate that the Graph Transformer model, despite its small size, is competitive with trRosetta's estimates. It is worth noting that the architecture of the present invention resolves the backbone structure as its main output and produces distances uniquely and deterministically, whereas trRosetta operates within a probabilistic domain that does not require three-dimensional resolution. These results thus suggest potential for improved predictive capability with larger model capacity and downstream protein refinement.
[0039] Importantly, in contrast to existing approaches, the present invention is highly computationally efficient and can be performed using a fairly small cluster of machines.
[0040] The present invention revisits the protein folding problem and highlights the role of unsupervised language models in providing a meaningful basis for the sequence-to-structure prediction task. It provides a strategy to encapsulate MSA Transformer embeddings and attention traces in a geometric framework, and formalizes a graph learning pipeline to reason over positional information.
[0041] Overall, the results demonstrate the remarkable expressive power of language models and, in particular, of MSA-augmented architectures. To demonstrate a versatile bridge between sequence and three-dimensional structure, a downstream model was trained to produce Cα traces which, before any refinement is performed, induce distograms with high similarity to ground truth.
[0042] The model, in its currently preferred embodiment, tackles only a step of the protein structure prediction problem. With only 382K parameters, it serves as a fast and scalable solution to resolving the position of protein backbones. Furthermore, it extends learning beyond distogram prediction and provides a natural foundation for downstream tasks, such as side chain prediction and protein refinement. It is hypothesized that, by increasing model capacity, dataset size, and training time, the model's predictive capability can improve significantly.
[0043] The present invention builds upon recent groundbreaking work in protein representation learning and protein language modeling. The integration of diverse network architectures and pretrained models, as demonstrated by the present invention, will enable the eventual efficient solution of the protein structure prediction problem.
[0044] At least the following aspects, implementations, modifications, and applications of the described technology are contemplated by the inventors and are considered to be aspects of the presently claimed invention:
[0045] (1) Methods of folding a protein sequence in silico employing attention-based graph transformer architectures.
[0046] (2) Refinement of structures determined via the method of (1), utilizing physical and molecular simulations, in silico relaxation, and 3D roto-translation equivariant attention networks (SE3 transformers), according to techniques known in the art of the invention.
[0047] Some aspects of the invention incorporate methodologies that are disclosed via reference to one or more cited references. These methodologies are described in detail in one or more of the cited references, all of which are incorporated by reference herein.
[0048] While preferred embodiments of the invention are disclosed herein, many other implementations will occur to one of ordinary skill in the art and are all within the scope of the invention. Each of the various embodiments described above may be combined with other described embodiments in order to provide multiple features. Furthermore, while the foregoing describes a number of separate embodiments of the apparatus and method of the present invention, what has been described herein is merely illustrative of the application of the principles of the present invention. Other arrangements, methods, modifications, and substitutions by one of ordinary skill in the art are therefore also considered to be within the scope of the present invention.