Full Attention with Sparse Computation Cost
20230022151 · 2023-01-26
Inventors
- Hanjun Dai (San Jose, CA, US)
- Bo Dai (San Jose, CA, US)
- Hongyu Ren (Stanford, CA, US)
- Dale Eric Schuurmans (Edmonton, CA)
- Zihang Dai (Cupertino, CA, US)
- Mengjiao Yang (Berkeley, CA, US)
Abstract
The present disclosure is directed to machine learning model architectures which provide full attention capability in each attention head while maintaining low computation and memory complexity. Specifically, according to one aspect of the present disclosure, example attention models provided herein can treat the self-attention mechanism as a conditional expectation over embeddings at each location and approximate the conditional distribution with a structured factorization. Each location can attend to all other locations, either via direct attention, or through indirect attention to group representations, which are again conditional expectations of embeddings from corresponding local regions.
Claims
1. A computing system for performing an attention mechanism with reduced computational requirements, the computing system comprising: one or more processors; and one or more non-transitory computer-readable media that collectively store a machine-learned attention model configured to receive and process a model input to generate a model output, wherein the machine-learned attention model comprises one or more attention layers, wherein at least one of the attention layers comprises one or more attention heads, and wherein at least one of the attention heads is configured to: receive a sequence of input data elements; and apply a structured attention pattern to the sequence of input data elements to generate a sequence of output data elements; wherein, for each input data element in the sequence of input data elements, the structured attention pattern specifies one or more locations of direct expectation and one or more groups of locations of local expectation; and wherein, for each of the one or more groups of locations with local expectation, the at least one of the attention heads is configured to: determine a single group probability for the group of locations; and determine an individual local expectation for each location in the group of locations.
2. The computing system of claim 1, wherein the structured attention pattern comprises a full attention pattern that has a support that covers an entirety of the sequence of input data elements.
3. The computing system of claim 1, wherein the at least one of the attention heads is configured to re-use the individual local expectation for each location in the group of locations when applying the structured attention pattern for two or more different input data elements in the sequence of input data elements.
4. The computing system of claim 1, wherein the sequence of input data elements comprises a sequence of input embeddings.
5. The computing system of claim 1, wherein the structured attention pattern specifies a plurality of groups of locations of local expectation.
6. The computing system of claim 1, wherein the machine-learned attention model comprises a plurality of attention layers, wherein each of the plurality of attention layers comprises a plurality of attention heads, and wherein each of the plurality of attention heads is configured to apply the structured attention pattern.
7. The computing system of claim 1, wherein the structured attention pattern comprises a partition tree having two or more hierarchical partition levels.
8. The computing system of claim 1, wherein the at least one of the attention heads is configured to: for each of the one or more groups of locations with local expectation, normalize the individual local expectations for the group of locations; and normalize the one or more locations of direct expectation and the single group probabilities for the one or more groups of locations.
9. The computing system of claim 1, wherein the structured attention pattern comprises a combiner-fixed attention pattern.
10. The computing system of claim 1, wherein the structured attention pattern comprises a combiner-logsparse attention pattern.
11. The computing system of claim 1, wherein the structured attention pattern comprises a combiner-axial attention pattern.
12. The computing system of claim 1, wherein the structured attention pattern comprises a machine-learned factorization plan that specifies the one or more locations of direct expectation and the one or more groups of locations of local expectation.
13. The computing system of claim 1, wherein the model input comprises natural language data.
14. The computing system of claim 1, wherein the model input comprises image data, audio data, protein data, or computer-readable code data.
15. A computer-implemented method for performing an attention mechanism with reduced computational requirements, the method comprising: receiving a sequence of input data elements; and applying a structured attention pattern to each of the sequence of input data elements to generate a sequence of output data elements, wherein applying the structured attention pattern to each input data element comprises: determining one or more locations of direct expectation and one or more groups of locations of local expectation; and for each of the one or more locations of direct expectation, determining a direct expectation; and for each of the one or more groups of locations with local expectation: determining a single group probability for the group of locations; and determining an individual local expectation for each location in the group of locations.
16. The computer-implemented method of claim 15, wherein the structured attention pattern has a support that covers an entirety of the sequence of input data elements.
17. The computer-implemented method of claim 15, wherein, for at least one of the one or more groups of locations with local expectation, determining the individual local expectation for each location in the group of locations comprises re-using the individual local expectation for each location in the group of locations that was previously computed for a different input data element in the sequence of input data elements.
18. The computer-implemented method of claim 15, wherein the sequence of input data elements comprises a sequence of input embeddings.
19. The computer-implemented method of claim 15, wherein the structured attention pattern specifies a plurality of groups of locations of local expectation.
20. One or more non-transitory computer-readable media that collectively store: a machine-learned attention model configured to receive and process a model input to generate a model output, wherein the machine-learned attention model comprises one or more attention layers, wherein at least one of the attention layers comprises one or more attention heads, and wherein at least one of the attention heads is configured to: receive a sequence of input data elements; and apply a structured attention pattern to the sequence of input data elements to generate a sequence of output data elements; wherein, for each input data element in the sequence of input data elements, the structured attention pattern specifies one or more locations of direct expectation and one or more groups of locations of local expectation; and wherein, for each of the one or more groups of locations with local expectation, the at least one of the attention heads is configured to: determine a single group probability for the group of locations; and determine an individual local expectation for each location in the group of locations.
Description
BRIEF DESCRIPTION OF THE DRAWINGS
[0014] Detailed discussion of embodiments directed to one of ordinary skill in the art is set forth in the specification, which makes reference to the appended figures.
[0020] Reference numerals that are repeated across plural figures are intended to identify the same features in various implementations.
DETAILED DESCRIPTION
Overview
[0021] Generally, the present disclosure is directed to machine learning model architectures which provide full attention capability in each attention head while maintaining low computation and memory complexity. Specifically, according to one aspect of the present disclosure, example attention models provided herein can treat the self-attention mechanism as a conditional expectation over embeddings at each location and approximate the conditional distribution with a structured factorization. Each location can attend to all other locations, either via direct attention, or through indirect attention to group representations, which are again conditional expectations of embeddings from corresponding local regions. The present disclosure also provides specific example attention patterns for full attention which roughly correspond to certain sparse patterns used in existing sparse transformers and result in the same sub-quadratic cost ($O(L \log L)$ or $O(L\sqrt{L})$).
[0022] The systems and methods described herein (example implementations of which can be referred to as “Combiner”) are a drop-in replacement for attention layers in existing transformers and can be easily implemented in common frameworks. Example experimental evaluations on both autoregressive and bidirectional sequence tasks, contained in U.S. Provisional Patent Application No. 63/220,063, demonstrated the effectiveness of this approach, yielding state-of-the-art results on several image and text modeling tasks.
[0023] More particularly, the present disclosure provides an improved attention mechanism which can be used as a drop-in replacement for the vanilla quadratic attention mechanism with sub-quadratic computation and memory cost. The proposed approach can still achieve full attention capability within each head of Multi-Head Attention, unlike approaches that adopt sparse or low-rank approximations. In particular, in some implementations, the standard attention computed at each location can be seen as the conditional expectation of the value embeddings at all feasible locations given the current location.
[0024] Based on such an understanding, the proposed attention mechanism explicitly approximates the conditional distribution through a structured factorization of the probability space. Specifically, given a location x, the probability of attending to location y can be either directly calculated via the query vector of x and key vector of y, or indirectly through a local group-based approach where x first attends to the key vector that represents a group of locations containing y, and this group probability is then multiplied by the probability of choosing y within that group. Example implementations of this approach can be referred to as Combiner since the conditional distributions in attention become a combination of several local attentions and direct attentions. This structured decomposition enables the proposed attention mechanism to take existing sparse attention patterns and convert them into corresponding design choices for probability factorizations that achieve full attention.
[0025] Example implementations of the present disclosure can achieve full attention with the same asymptotic complexity as sparse variants. The proposed attention mechanism can be easily implemented in most existing deep learning frameworks without the need for specialized hardware implementation and is GPU/TPU friendly. In fact, both the fixed and learnable sparse attention patterns from many existing Transformer variants can be enhanced with such structured factorizations, with the same order of time or memory cost.
[0026] Example experiments contained in U.S. Provisional Patent Application No. 63/220,063 validate Combiner on both autoregressive and bidirectional sequence modeling tasks over a variety of domains including text and images. The experiments show that Combiner can achieve better perplexity and accuracy when using the same transformer architectures while being much faster in terms of runtime, and that it achieves state-of-the-art performance on density estimation on the standard datasets CIFAR-10 (2.77 bits/dim) and ImageNet-64 (3.42 bits/dim), as well as on the Long-Range Arena.
[0027] The systems and methods of the present disclosure provide a number of technical effects and benefits. As one example, the systems and methods of the present disclosure can enable full attention to be performed over long sequences with reduced computational cost, thereby resulting in savings of computational resources such as reduced memory usage, reduced processor usage, etc. The ability to perform full attention at reduced computational cost also provides for better performance (e.g., accuracy) from a machine-learned model in situations where a large input length previously foreclosed the use of full attention due to computational cost. Thus, the systems and methods of the present disclosure improve the performance of both the model and the computer itself while also conserving computing resources.
[0028] With reference now to the Figures, example embodiments of the present disclosure will be discussed in further detail.
Example Attention Models
[0030] The machine-learned attention model 12 can be configured to receive and process a model input 14 to generate a model output 16. The model input 14 can be any form of data including raw textual or natural language data, textual or natural language embeddings, audio data, image data, sensor data, protein data, and/or other forms of data such as various sequences of data.
[0031] The machine-learned attention model 12 can include one or more attention layers (illustrated as example attention layers 18, 20, and 22). Some or all of the attention layers can include one or more attention heads. For example, attention layer 20 is shown as including four attention heads, including attention head 24. Any number of layers and/or heads can be used.
[0032] Some or all of the attention heads (e.g., head 24) can be configured to receive a sequence of input data elements 26 and apply a structured attention pattern to the sequence of input data elements to generate a sequence of output data elements 28.
[0033] According to an aspect of the present disclosure, for each input data element in the sequence of input data elements, the structured attention pattern can specify one or more locations of direct expectation and one or more groups of locations of local expectation. Each group of locations can contain any number of locations. The groups can be the same size (number of locations) or different sizes (numbers of locations).
[0034] The attention head 24 can apply the structured attention pattern as follows: For each of the one or more locations of direct expectation, the attention head 24 can determine a direct expectation. For each of the one or more groups of locations with local expectation, the attention head 24 can: determine a single group probability for the group of locations; and determine an individual local expectation for each location in the group of locations. The single group probability for a group can be determined for the group as a whole or for a representative member of the group.
[0035] As examples of this approach,
[0036] Referring to
[0037] In one example, the attention output for a given input $x_i$ can be expressed as
$$A(x_i) = \sum_{j \in \Omega_i^0} \tilde{p}(j \mid i)\, v_j + \sum_{r=1}^{n_i} p(\Omega_i^r \mid i) \sum_{j \in \Omega_i^r} p(j \mid \Omega_i^r)\, v_j,$$
where $\Omega_i^0$ denotes the set of locations with direct expectation and $\Omega_i^r$ denotes the different groups of locations with local expectation, with $r$ being the index of the groups, and $v_j$ being the value of the $j$th location.
[0038] Thus, in some implementations, applying the structured attention pattern 200 to a given input $x_i$ can include computing a direct expectation for each location of direct expectation in $\Omega_i^0$ (e.g., 202, 204, 206, etc.), computing a single group probability for each $\Omega_i^r$ (e.g., group 208), and computing an individual local expectation for each location within a group of locations (e.g., a local expectation for location 212, a local expectation for location 214, etc.). The final attention can then be provided as shown in the expression above.
[0039] In some implementations, for example as shown in the expression above, the local expectation need not depend on $x_i$ and can therefore be re-used for multiple different input elements, thereby reducing the number of computations that need to be performed.
Attention as Conditional Expectation
[0040] This section revisits the formulation of the standard Transformer from the perspective of conditional expectation, which inspires the derivation of Combiner.
[0041] Without loss of generality, and for ease of description, this disclosure considers a single sequence in the self-attention scenario. Given a sequence of $L$ embeddings $X = [x_1, x_2, \ldots, x_L]$, where $X \in \mathbb{R}^{L \times d}$ and each embedding $x_i \in \mathbb{R}^d$ is a $d$-dimensional vector, the core component of the Transformer is multi-head attention, where each head $h$ is a scaled dot-product attention:
$$A_h(X) = \mathrm{softmax}\!\left(\frac{Q_h K_h^\top}{\sqrt{d}}\right) V_h, \qquad \{Q_h, K_h, V_h\} = X\{W_h^Q, W_h^K, W_h^V\}, \qquad W_h^Q, W_h^K, W_h^V \in \mathbb{R}^{d \times d}, \tag{1}$$
and the attention vector from each head $A_h(X)$ is concatenated and projected:
$$\mathrm{MultiHeadAttn}(X) = [A_1(X), A_2(X), \ldots, A_H(X)]\, W^o, \qquad W^o \in \mathbb{R}^{Hd \times d}. \tag{2}$$
[0042] Here $H$ is the total number of heads per Transformer layer. This disclosure describes how to approximate full attention within each head of multi-head attention. For ease of notation, we drop the head index $h$ whenever possible, and use lower-case letters $x_i, q_i, k_i, v_i \in \mathbb{R}^d$ to denote rows in $X$, $Q$, $K$, $V$, respectively, which correspond to a location $i$ in the original sequence of length $L$. We use $[n]$ to denote the set of positive integers $\{1, 2, \ldots, n\}$.
[0043] For a position $i \in [L]$, the attention formulation (1) can be viewed as a conditional expectation of rows in $V$. Specifically, since softmax outputs a probability distribution, we can rewrite (1) as
$$A(x_i) = \mathbb{E}_{p(j \mid i)}[v_j] = \sum_{j \in \Omega_i} p(j \mid i)\, v_j, \qquad p(j \mid i) = \frac{1}{Z(x_i)} \exp\!\left(\frac{q_i k_j^\top}{\sqrt{d}}\right), \tag{3}$$
where $p(j \mid i)$ denotes the conditional probability at position $j$ given the token at position $i$, and the partition function
$$Z(x_i) = \sum_{j \in \Omega_i} \exp\!\left(\frac{q_i k_j^\top}{\sqrt{d}}\right)$$
is taken over the support $\Omega_i$. The support $\Omega_i$ of $p(j \mid i)$ defines the set of valid locations that the $i$-th token can attend to. For instance, the support set in autoregressive language modeling (LM) consists of all previous tokens, i.e., $\Omega_i^{\mathrm{LM}} = [i]$; in masked language modeling (MLM), the support consists of all tokens in the sequence, i.e., $\Omega_i^{\mathrm{MLM}} = [L]$. That is, $\Omega_i^{\mathrm{LM}}$ and $\Omega_i^{\mathrm{MLM}}$ represent full attention capability in the LM and MLM settings, respectively.
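The rewrite in (3) can be checked numerically. The following is a minimal sketch, assuming plain Python lists for vectors and 0-indexed positions (so the LM support of position $i$ is `range(i + 1)` and the MLM support is `range(L)`); the function name `attention_as_expectation` is hypothetical and not part of the disclosure.

```python
import math
from typing import Dict, Iterable, List, Sequence, Tuple


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def attention_as_expectation(
    q_i: Sequence[float],
    keys: Sequence[Sequence[float]],
    values: Sequence[Sequence[float]],
    support: Iterable[int],
) -> Tuple[List[float], Dict[int, float]]:
    """Return A(x_i) = sum_{j in support} p(j|i) v_j together with p(.|i)."""
    d = len(q_i)
    logits = {j: dot(q_i, keys[j]) / math.sqrt(d) for j in support}
    shift = max(logits.values())  # subtracting the max leaves p(j|i) unchanged
    weights = {j: math.exp(s - shift) for j, s in logits.items()}
    z = sum(weights.values())  # partition function Z(x_i), rescaled by exp(-shift)
    probs = {j: w / z for j, w in weights.items()}
    out = [sum(probs[j] * values[j][t] for j in probs) for t in range(len(values[0]))]
    return out, probs


keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0], [2.0], [3.0]]
q = [0.5, -0.5]

# LM support of position 0 is {0}, so the expectation is v_0 itself.
lm_out, _ = attention_as_expectation(q, keys, values, range(1))
# MLM support covers all L = 3 positions.
mlm_out, mlm_probs = attention_as_expectation(q, keys, values, range(len(keys)))
print(lm_out, mlm_out, sum(mlm_probs.values()))
```

The dictionary `probs` is the conditional distribution $p(j \mid i)$ over $\Omega_i$; shrinking `support` is what a sparse transformer does, whereas Combiner keeps the full support and changes how $p(j \mid i)$ is factorized.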
Full Attention Via Structured Conditional Expectation
[0044] The complexity of $p(j \mid i)$ is the bottleneck of the computation for $A(x_i)$. Generally, in existing sparse transformers, the support of $p(j \mid i)$ is sparsified to reduce the computation and memory complexity, e.g., $\Omega_i^{\mathrm{sparse}} \subsetneq \Omega_i^{\mathrm{LM}}$ for LM and $\Omega_i^{\mathrm{sparse}} \subsetneq \Omega_i^{\mathrm{MLM}}$ for MLM, but this can lead to either reduced capacity or limited applicability. This section introduces the Combiner, which achieves $\Omega_i^{\mathrm{Combiner}} = \Omega_i^{\mathrm{LM}}$ for LM and $\Omega_i^{\mathrm{Combiner}} = \Omega_i^{\mathrm{MLM}}$ for MLM while still maintaining sub-quadratic computation and memory cost. Below, we denote by $\Omega_i$ the support for full attention when there is no ambiguity or need to distinguish between LM and MLM.
[0045] Local Factorization for Conditional Expectation
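[0046] One main idea described herein is to exploit a hierarchical structure for conditional probability modeling in Eq. (3), which provides the opportunity to reduce computational complexity while maintaining the same support. Specifically, we introduce support variables $\Omega_i^r$ for $r = 0, \ldots, n_i$ and $i \in [L]$. The support variables are disjoint, i.e., $\Omega_i^r \cap \Omega_i^s = \emptyset$ for all $r \neq s$, and $\bigcup_{r=0}^{n_i} \Omega_i^r = \Omega_i$. Then $p(j \mid i)$ can be factorized as
$$p(j \mid i) = \sum_{r=0}^{n_i} p(j, \Omega_i^r \mid i) = \sum_{r=0}^{n_i} p(j \mid \Omega_i^r, i)\, p(\Omega_i^r \mid i) = p(j \mid \Omega_i^{r_j}, i)\, p(\Omega_i^{r_j} \mid i), \tag{4}$$
where $r_j$ denotes the index of the support to which $j$ belongs. The last equation arises from the fact that the $\Omega_i^r$ are disjoint from each other ($\Omega_i^r \cap \Omega_i^s = \emptyset$ for all $r \neq s$). Therefore, there is only one support, $\Omega_i^{r_j}$, that contains $j$; the remaining terms, for which $j \notin \Omega_i^r$ with $r \neq r_j$, are all zero since $p(j \mid \Omega_i^r, i) = 0$.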
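[0047] Furthermore, assume that $j$ and $i$ are conditionally independent given $\Omega_i^{r_j}$, i.e., $p(j \mid \Omega_i^{r_j}, i) = p(j \mid \Omega_i^{r_j})$. We then obtain the local factorization
$$p(j \mid i) = p(j \mid \Omega_i^{r_j})\, p(\Omega_i^{r_j} \mid i). \tag{5}$$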
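[0048] Given the partition $\{\Omega_i^r\}_{r=0}^{n_i}$, the attention form in (3) can be rewritten as
$$A(x_i) = \mathbb{E}_{p(j \mid i)}[v_j] = \sum_{r=0}^{n_i} \sum_{j \in \Omega_i^r} p(j, \Omega_i^r \mid i)\, v_j \tag{6}$$
$$= \sum_{j \in \Omega_i^0} \tilde{p}(j \mid i)\, v_j + \sum_{r=1}^{n_i} p(\Omega_i^r \mid i) \sum_{j \in \Omega_i^r} p(j \mid \Omega_i^r)\, v_j, \tag{7}$$
where we consider direct attention in partition $\Omega_i^0$ and apply the local factorization (5) to the partitions $r = 1, \ldots, n_i$. Here $\tilde{p}(j \mid i) \propto p(j \mid i)$ but with different normalization constants, which will be explained below. We refer to this model as Combiner since the structured attention (7) combines the direct expectation of $\Omega_i^0$ and multiple local expectations via $p(j \mid \Omega_i^r)$ and $p(\Omega_i^r \mid i)$ to form the final conditional expectation.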
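Equations (4) and (6) are exact; only the conditional-independence step (5) introduces an approximation. The sketch below, assuming a single query whose distribution $p(j \mid i)$ is known and an arbitrary partition into a direct set and groups (all names hypothetical), checks that the exact, query-dependent factors $p(\Omega_i^r \mid i) = \sum_{j \in \Omega_i^r} p(j \mid i)$ and $p(j \mid \Omega_i^r, i) = p(j \mid i) / p(\Omega_i^r \mid i)$ recover the vanilla expectation (3).

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    shift = max(logits)
    w = [math.exp(s - shift) for s in logits]
    z = sum(w)
    return [x / z for x in w]


def exact_factorized_expectation(
    p: Sequence[float],
    values: Sequence[Sequence[float]],
    direct: Sequence[int],
    groups: Sequence[Sequence[int]],
) -> List[float]:
    """Evaluate Eq. (6) with exact factors; direct and groups must partition the support."""
    dim = len(values[0])
    out = [0.0] * dim
    for j in direct:  # r = 0: direct expectation
        for t in range(dim):
            out[t] += p[j] * values[j][t]
    for g in groups:  # r >= 1: group probability times within-group probability
        p_group = sum(p[j] for j in g)  # p(Omega_i^r | i)
        for j in g:
            p_local = p[j] / p_group  # p(j | Omega_i^r, i): still depends on i here
            for t in range(dim):
                out[t] += p_group * p_local * values[j][t]
    return out


p = softmax([0.3, -1.2, 2.0, 0.5, 0.0])
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]
vanilla = [sum(p[j] * values[j][t] for j in range(5)) for t in range(2)]
factored = exact_factorized_expectation(p, values, direct=[0], groups=[[1, 2], [3, 4]])
print(vanilla, factored)
```

Combiner replaces `p_local` with a query-independent $p(j \mid \Omega_i^r)$; that substitution is what lets a group's local expectation be computed once and shared by every query that uses the same group.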
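[0049] Equivalently, we can also rewrite the structured attention (7) as
$$A(x_i) = \sum_{j \in \Omega_i} \Big[ \underbrace{\mathbb{I}(j \in \Omega_i^0)\, \tilde{p}(j \mid i) + \sum_{r=1}^{n_i} \mathbb{I}(j \in \Omega_i^r)\, p(j \mid \Omega_i^r)\, p(\Omega_i^r \mid i)}_{q(j \mid i)} \Big] v_j, \tag{8}$$
where $\mathbb{I}(\cdot)$ is a binary indicator function. After reordering, one can see from (8) that we obtain the effective conditional probability $q(j \mid i)$ that approximates the original $p(j \mid i)$. Each probability term depends on both the current location $i$ and the other location $j$, and the expectation is still taken with respect to a valid conditional probability (non-negative and summing to 1 over $\Omega_i$).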
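The equivalence of (7) and (8) can be illustrated with made-up probabilities (illustrative numbers only, not results). The sketch assumes that $\tilde{p}(j \mid i)$ and the group probabilities $p(\Omega_i^r \mid i)$ have already been normalized jointly, and each $p(j \mid \Omega_i^r)$ within its group; the group memberships are implied by the keys of each local dictionary, and the function names are hypothetical.

```python
from typing import Dict, List, Sequence


def effective_q(direct_p: Dict[int, float], group_p: Sequence[float],
                local_p: Sequence[Dict[int, float]]) -> Dict[int, float]:
    """Eq. (8): q(j|i) = 1[j in Omega^0] p~(j|i) + sum_r 1[j in Omega^r] p(j|Omega^r) p(Omega^r|i)."""
    q = dict(direct_p)
    for g_prob, local in zip(group_p, local_p):
        for j, p_j in local.items():
            q[j] = q.get(j, 0.0) + g_prob * p_j  # groups are disjoint, so each j is hit once
    return q


def structured_attention(direct_p: Dict[int, float], group_p: Sequence[float],
                         local_p: Sequence[Dict[int, float]],
                         values: Sequence[Sequence[float]]) -> List[float]:
    """Eq. (7): direct expectation plus group-weighted local expectations."""
    dim = len(values[0])
    out = [sum(p * values[j][t] for j, p in direct_p.items()) for t in range(dim)]
    for g_prob, local in zip(group_p, local_p):
        local_exp = [sum(p * values[j][t] for j, p in local.items()) for t in range(dim)]
        out = [o + g_prob * e for o, e in zip(out, local_exp)]
    return out


direct_p = {0: 0.2, 1: 0.1}               # p~(j|i) on Omega_i^0
group_p = [0.3, 0.4]                      # p(Omega_i^r|i); jointly with direct_p sums to 1
local_p = [{2: 0.5, 3: 0.5}, {4: 0.1, 5: 0.9}]  # p(j|Omega_i^r), each sums to 1
values = [[float(j)] for j in range(6)]

q = effective_q(direct_p, group_p, local_p)
a7 = structured_attention(direct_p, group_p, local_p, values)
a8 = [sum(q[j] * values[j][0] for j in q)]
print(a7, a8, sum(q.values()))
```

Both routes give the same output, and $q(j \mid i)$ sums to 1 because the direct and group terms were normalized together while each local distribution was normalized on its own.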
[0050] Requirement for Sub-quadratic Cost. The benefit of this formulation can immediately be seen from the fact that the local expectation in (7) is independent of the position $i$. The full dependence on $i$ is achieved via the multiplier $p(\Omega_i^r \mid i)$, where $j \in \Omega_i^r$. If we can design the local factorization such that:
[0051] 1. the order of the number of terms in (7) for $p(\cdot \mid i)$, $\forall i \in [L]$, namely $\sum_{i=1}^{L} (n_i + |\Omega_i^0|)$, is sub-quadratic;
[0052] 2. letting $\mathcal{U} = \{\Omega_i^r\}_{i \in [L],\, r \in [1, n_i]}$ denote the collection of all partitions, $|\mathcal{U}|$ (i.e., the number of unique partitions in $\mathcal{U}$) is sub-quadratic; and
[0053] 3. the order of the total number of unique calculations of local expectations across all locations in (7), $\sum_{\Omega \in \mathcal{U}} |\Omega|$, is sub-quadratic,
[0054] then the overall computation and memory cost will be sub-quadratic with full attention support $\Omega_i^{\mathrm{Combiner}} = \Omega_i$, $\forall i \in [L]$.
[0055] Remark (Further Hierarchical Decomposition): The local decomposition with a one-layer partition of the support of $p(\cdot \mid i)$ is introduced for simplicity. In fact, such local decompositions can be stacked further, which introduces a partition tree. Specifically, we can further partition $\Omega_i^r$ into disjoint subsets $\{\Omega_i^{rk}\}_{k}$.
[0056] Parameterizing Conditional Probabilities
[0057] While we obtained a possible way to speed up the standard Transformer via a combination of direct expectation and local expectations, it is also beneficial to have an efficient design choice for the probability terms in (7), namely $\tilde{p}(j \mid i)$ from the direct expectation, $p(j \mid \Omega_i^r)$ from the local expectation, and $p(\Omega_i^r \mid i)$ for $r \in [1, n_i]$. For simplicity and as an example, one can use the scaled dot-product, which means that we associate positions $i, j$ and variable sets $\Omega_i^r$ with corresponding embedding representations, so that each probability is proportional to the exponential of the embedding inner products. Specifically:
[0058] $\tilde{p}(j \mid i)$: As this term is for the direct expectation, we can let
$$\tilde{p}(j \mid i) \propto \exp\!\left(\frac{q_i k_j^\top}{\sqrt{d}}\right),$$
which is the same as vanilla attention (3) but with a different normalization, which is explained in Equation (9).
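[0059] $p(\Omega_i^r \mid i)$: This term aims to capture the joint event probability,
$$p(\Omega_i^r \mid i) \propto \exp\!\left(\frac{q_i k_{\Omega_i^r}^\top}{\sqrt{d}}\right).$$
Thus the design choice of $k_{\Omega_i^r}$ should provide an abstraction of the corresponding support $\Omega_i^r$.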
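[0060] $p(j \mid \Omega_i^r)$: This term is the probability of getting $j$ within the local span $\Omega_i^r$. We make
$$p(j \mid \Omega_i^r) \propto \exp\!\left(\frac{q_{\Omega_i^r} k_j^\top}{\sqrt{d}}\right),$$
where we use max pooling or DeepSets over $\{q_j\}_{j \in \Omega_i^r}$ to obtain $q_{\Omega_i^r}$.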
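[0061] Normalizing Probability Terms. The terms in each local expectation $p(j \mid \Omega_i^r)$, $\forall j \in \Omega_i^r$, can be normalized within the local span; the direct expectation $\tilde{p}(j \mid i)$ and the terms in $p(\Omega_i^r \mid i)$ can be normalized together:
$$p(j \mid \Omega_i^r) = \frac{\exp\!\big(q_{\Omega_i^r} k_j^\top / \sqrt{d}\big)}{\sum_{l \in \Omega_i^r} \exp\!\big(q_{\Omega_i^r} k_l^\top / \sqrt{d}\big)}, \qquad Z(x_i) = \sum_{j \in \Omega_i^0} \exp\!\left(\frac{q_i k_j^\top}{\sqrt{d}}\right) + \sum_{r=1}^{n_i} \exp\!\left(\frac{q_i k_{\Omega_i^r}^\top}{\sqrt{d}}\right), \tag{9}$$
where $Z(x_i)$ is the normalizing constant when calculating $\tilde{p}(j \mid i)$ and $p(\Omega_i^r \mid i)$.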
[0062] Example Trade-Offs
[0063] Combiner achieves full attention with reduced cost without making explicit sparsity or low-rank assumptions over the attention matrix. However, this efficiency gain is not free. This section discusses the limitations of the simplification made by Combiner and provides a simple workaround.
[0064] Structured Attention Approximation.
[0065] We obtain the local decomposition (5) under the conditional independence assumption. Therefore, the local expectation in (7) is independent of the position $i$. This suggests that any two locations $i_1$ and $i_2$ with $\Omega_{i_1}^r = \Omega_{i_2}^r = \Omega$ have linearly dependent attention scores over the region $\Omega$: by (8), $q(j \mid i_1) / q(j \mid i_2) = p(\Omega \mid i_1) / p(\Omega \mid i_2)$ for every $j \in \Omega$.
In other words, the rank of the sub-matrix over the same partition in the resulting attention matrix is 1; therefore, the attention matrix is locally low-rank based on the partition. On the other hand, the direct expectation fully attends to each position in the sub-support $\Omega_i^0$, which ensures the full-rank block. These two attention schemes make the attention matrix of Combiner structured. Compared with a low-rank approximation for attention, a structured approximation that exploits both locally low-rank and full-rank blocks has been shown to be more powerful, both theoretically and empirically, in large-scale kernel machines.
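The rank-1 claim follows directly from $q(j \mid i) = p(\Omega \mid i)\, p(j \mid \Omega)$ for $j \in \Omega$: the block is an outer product. A small sketch with illustrative numbers (not taken from any experiment; helper names hypothetical) checks that every $2 \times 2$ minor of such a block vanishes, which holds exactly when the block has rank at most 1.

```python
from itertools import combinations
from typing import List, Sequence


def group_block(group_probs: Sequence[float], local_probs: Sequence[float]) -> List[List[float]]:
    """Rows q(j|i) = p(Omega|i) * p(j|Omega) for several queries sharing the group Omega."""
    return [[g * p for p in local_probs] for g in group_probs]


def max_abs_2x2_minor(rows: Sequence[Sequence[float]]) -> float:
    """Largest |2x2 minor|; all minors vanish exactly when the block has rank <= 1."""
    worst = 0.0
    for a, b in combinations(range(len(rows)), 2):
        for c, e in combinations(range(len(rows[0])), 2):
            worst = max(worst, abs(rows[a][c] * rows[b][e] - rows[a][e] * rows[b][c]))
    return worst


combiner_block = group_block([0.3, 0.6, 0.1], [0.2, 0.5, 0.3])  # three queries, one shared group
generic_block = [[0.1, 0.2], [0.3, 0.1]]  # an arbitrary 2x2 block of attention weights
print(max_abs_2x2_minor(combiner_block), max_abs_2x2_minor(generic_block))
```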
[0066] Improving Expressiveness Using a Mixture Model.
[0067] One way to further improve the expressiveness of the local factorization is to use a mixture model. This idea has been used to obtain a high-rank softmax layer in language modeling. Let $\omega$ be a certain partition of the support $\Omega_i$ (i.e., a collection of $\Omega_i^r$); then one can easily use
$$A(x_i) = \frac{1}{M} \sum_{m=1}^{M} A(x_i; \omega_m)$$
to compute the attention, where each component of the mixture $A(x_i; \omega_m)$ is the term (7) using a specific factorization plan $\omega_m$. Empirically, two components were found to be sufficient to improve performance.
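A minimal sketch of why mixing helps, using illustrative numbers only: under plan A, locations 0 and 1 share a group, so two queries' weights on $\{0, 1\}$ form a rank-1 block; plan B groups the locations differently. Averaging the two effective distributions, as in the uniform mixture above, keeps a valid distribution but breaks the rank-1 constraint. For brevity the sketch assumes no direct set, so every location is reached through a group; the helper names are hypothetical.

```python
from typing import Dict, Sequence, Tuple


def plan_weights(group_probs: Sequence[float], groups: Sequence[Tuple[int, ...]],
                 local_probs: Sequence[Sequence[float]]) -> Dict[int, float]:
    """Effective q(j|i) for one query when every location is reached through some group."""
    q: Dict[int, float] = {}
    for g_prob, members, l_probs in zip(group_probs, groups, local_probs):
        for j, p in zip(members, l_probs):
            q[j] = g_prob * p
    return q


def mixture(components: Sequence[Dict[int, float]]) -> Dict[int, float]:
    """Uniform mixture (1/M) * sum_m q_m(j|i) over M factorization plans."""
    support = set().union(*components)
    return {j: sum(c.get(j, 0.0) for c in components) / len(components) for j in support}


def minor_01(row_a: Dict[int, float], row_b: Dict[int, float]) -> float:
    """2x2 minor of two queries' weights on locations {0, 1}."""
    return row_a[0] * row_b[1] - row_a[1] * row_b[0]


# (groups, local probabilities); local probabilities are shared by all queries.
plan_a = ((0, 1), (2, 3)), ((0.5, 0.5), (0.25, 0.75))
plan_b = ((0, 2), (1, 3)), ((0.5, 0.5), (0.5, 0.5))
a1, a2 = (plan_weights(gp, *plan_a) for gp in ([0.8, 0.2], [0.4, 0.6]))
b1, b2 = (plan_weights(gp, *plan_b) for gp in ([0.9, 0.1], [0.3, 0.7]))
m1, m2 = mixture([a1, b1]), mixture([a2, b2])
print(minor_01(a1, a2), minor_01(m1, m2), sum(m1.values()), sum(m2.values()))
```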
Example Instantiations
[0068] This section shows several example local factorization schemes satisfying the requirements described herein. As shown, Combiner is able to convert several sparse transformers into full attention, with the same order of computation and memory consumption. One can also design other factorization patterns, which can be easily instantiated in Combiner.
[0069] Combiner-Fixed
[0070] The Sparse Transformer is one of the most representative variants that can achieve $O(L\sqrt{L})$ computation and memory cost with sparse attention. See Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever. Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509, 2019.
[0071] Here it is shown how to convert a fixed pattern into a factorization plan and instantiate a full attention variant named Combiner-Fixed.
[0072] In fixed-sparse attention, the support is $\Omega_i^{\mathrm{sparseMLM}} = \{j : j \bmod s = 0\} \cup \{j : j \equiv i \ (\mathrm{div}\ s)\}$, where $s$ is a hyper-parameter, div is integer division, and $j \equiv i \ (\mathrm{div}\ s)$ denotes that the quotients of $i$ and $j$ with respect to $s$ are the same. In the autoregressive case, $\Omega_i^{\mathrm{sparseLM}} = \Omega_i^{\mathrm{sparseMLM}} \cap [i]$. Please refer to the appended figures.
[0073] Our design of $\omega_{\mathrm{fixed}}^{\mathrm{MLM}}$ has the following form:
$$\Omega_i^0 = \{j : j \equiv i \ (\mathrm{div}\ s)\}, \qquad \Omega_i^r = \{j : j \,\mathrm{div}\, s = r,\ j \notin \Omega_i^0\}, \qquad \forall r \in [L \,\mathrm{div}\, s],\ \forall i \in [L], \tag{10}$$
where each local expectation is performed in each span of size $s$, and there are in total $L \,\mathrm{div}\, s$ spans across all locations. For each position $i \in [L]$, there are $O(s + (L \,\mathrm{div}\, s))$ terms in (7); the local expectation has $O(L \,\mathrm{div}\, s)$ terms. The overall complexity is $O(L \cdot (s + 2(L \,\mathrm{div}\, s)))$. The optimal $s$ is $O(\sqrt{L})$, and we can achieve $O(L\sqrt{L})$ computation and memory complexity, which is the same as [14], but here we gain full attention capability in each attention head. For the LM case, we can simply have $\omega_{\mathrm{fixed}}^{\mathrm{LM}} = \{\Omega_i^r \cap [i] \mid \Omega_i^r \in \omega_{\mathrm{fixed}}^{\mathrm{MLM}}\}$, which has the same $O(L\sqrt{L})$ optimal complexity.
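The cost argument for Combiner-Fixed can be checked by counting. The sketch below assumes 0-indexed positions and spans $[bs, (b+1)s)$, so the direct set of $i$ is its own span and every other span is one group; the helper names `fixed_plan_mlm`, `plan_costs`, and `covers_support` are hypothetical. `plan_costs` reports the three quantities from the sub-quadratic requirements: $\sum_i (n_i + |\Omega_i^0|)$, the number of unique groups, and the total size of the unique groups.

```python
from typing import FrozenSet, List, Set, Tuple

Plan = List[Tuple[Set[int], List[FrozenSet[int]]]]


def fixed_plan_mlm(L: int, s: int) -> Plan:
    """Combiner-Fixed (MLM), 0-indexed: own span is attended directly, every other span is a group."""
    blocks = [frozenset(range(b * s, min((b + 1) * s, L))) for b in range((L + s - 1) // s)]
    plan: Plan = []
    for i in range(L):
        own = i // s
        plan.append((set(blocks[own]), [blk for b, blk in enumerate(blocks) if b != own]))
    return plan


def plan_costs(plan: Plan) -> Tuple[int, int, int]:
    """(sum_i (n_i + |Omega_i^0|), number of unique groups, total size of unique groups)."""
    per_query_terms = sum(len(direct) + len(groups) for direct, groups in plan)
    unique_groups = {g for _, groups in plan for g in groups}
    return per_query_terms, len(unique_groups), sum(len(g) for g in unique_groups)


def covers_support(plan: Plan, L: int) -> bool:
    """True if, for every i, the direct set and groups are disjoint and cover [L] (full attention)."""
    for direct, groups in plan:
        parts = [direct, *groups]
        if sum(len(p) for p in parts) != L or set().union(*parts) != set(range(L)):
            return False
    return True


L = 64
for s in (2, 4, 8, 16, 32):
    plan = fixed_plan_mlm(L, s)
    print(s, covers_support(plan, L), plan_costs(plan), "vanilla:", L * L)
```

With $L = 64$, the per-query term count $L(s + L/s - 1)$ is smallest at $s = 8 = \sqrt{L}$, and every plan still covers the full support.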
[0074] Combiner-Logsparse
[0075] The Logsparse Transformer is proposed in Shiyang Li, Xiaoyong Jin, Yao Xuan, Xiyou Zhou, Wenhu Chen, Yu-Xiang Wang, and Xifeng Yan. Enhancing the locality and breaking the memory bottleneck of transformer on time series forecasting. In Advances in Neural Information Processing Systems (NeurIPS), 2019.
[0076] The Logsparse Transformer can theoretically achieve $O(L \log L)$ cost. The general idea is to make the size of the support $\Omega_i^{\mathrm{sparse}}$ no larger than $\lceil \log_2 i \rceil$. For ease of notation, we first define $\mathrm{bits}(n) = [b_1, b_2, \ldots, b_{\lceil \log_2 n \rceil}]$ to be the binary representation of the integer $n$.
[0077] To exploit this scheme in the Combiner framework, we can define $\lceil \log_2 n \rceil$ non-overlapping supports, where $\Omega_i^r = [\mathrm{suff}_r] \setminus [\mathrm{suff}_{r+1}]$, with the boundary case $[\mathrm{suff}_{\lceil \log_2 \cdots \rceil}]$ … summaries. Each location $i$ will select at most $O(\log i)$ non-overlapping spans to cover the full support $\Omega_i$, and thus the total cost will be $O(L \log L)$.
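The $O(L \log L)$ argument can be illustrated with one possible span construction. This is a sketch under our own assumption, not the exact suffix-based definition above: position $i$ (0-indexed) attends directly to itself, and the strictly earlier positions $[0, i)$ are covered by spans read off the binary representation of $i$, most significant bit first. Each span is aligned to its own size, so spans are shared across positions; the number of unique span summaries stays $O(L)$, while each position uses at most about $\log_2 i$ of them. The name `dyadic_spans` is hypothetical.

```python
from typing import List, Set, Tuple


def dyadic_spans(n: int) -> List[Tuple[int, int]]:
    """Cover the prefix [0, n) with half-open spans whose sizes follow the set bits of n."""
    spans, start = [], 0
    for t in reversed(range(n.bit_length())):
        if (n >> t) & 1:
            spans.append((start, start + (1 << t)))  # start is a multiple of 2^t here
            start += 1 << t
    return spans


L = 1000
all_spans: Set[Tuple[int, int]] = set()
max_per_position = 0
for i in range(L):  # position i: direct attention to itself, spans summarize [0, i)
    spans = dyadic_spans(i)
    all_spans.update(spans)
    max_per_position = max(max_per_position, len(spans))
print(len(all_spans), "unique spans;", max_per_position, "spans at most per position")
```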
[0078] Combiner-Axial
[0079] The Axial Transformer is described at Jonathan Ho, Nal Kalchbrenner, Dirk Weissenborn, and Tim Salimans. Axial attention in multidimensional transformers. arXiv preprint arXiv:1912.12180, 2019.
[0080] The Axial Transformer builds the attention along each axis of the input data. Without loss of generality, we focus on the 2D case, where the input sequence is reshaped into a matrix of size $n \times m = L$. Specifically, the location $i$ in the original sequence will be in $\mathrm{row}_i = (i - 1)\,\mathrm{div}\,m + 1$ and $\mathrm{col}_i = (i - 1) \bmod m + 1$. We show how to enable full attention with a factorization on the 2D matrix, hence Combiner-Axial.
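[0081] The sparse axial pattern has $\Omega_i^{\mathrm{sparseMLM}} = \{j : j - 1 \equiv i - 1 \ (\mathrm{mod}\ m)\} \cup \{j : j - 1 \equiv i - 1 \ (\mathrm{div}\ m)\}$ and $\Omega_i^{\mathrm{sparseLM}} = \Omega_i^{\mathrm{sparseMLM}} \cap [i]$, which all have at most $O(m + n)$ entries for each $i$, as illustrated in the appended figures.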
[0082] $\omega_{\mathrm{axial\text{-}vertical}}^{\mathrm{LM}}$: $\Omega_i^0 = \Omega_i^{\mathrm{sparseLM}}$, and $\Omega_i^r = \{j : j \equiv r \ (\mathrm{mod}\ m)\} \cap [i - \mathrm{col}_i]$ for $r \in [m] \setminus \{\mathrm{col}_i\}$. That is, each $\Omega_i^r$ summarizes column $r$ in the rows before $\mathrm{row}_i$.
[0083] $\omega_{\mathrm{axial\text{-}horizontal}}^{\mathrm{LM}}$: similar to $\omega_{\mathrm{axial\text{-}vertical}}$, except that each $\Omega_i^r$ summarizes row $r$ before $\mathrm{row}_i$ and excludes $\mathrm{col}_i$.
[0084] $\omega_{\mathrm{axial\text{-}rowmajor}}^{\mathrm{LM}}$: $\Omega_i^0 = \{j : j - 1 \equiv i - 1 \ (\mathrm{div}\ m)\} \cap [i]$, i.e., elements in the same row are directly attended, while $\Omega_i^r = \{j : j \equiv r \ (\mathrm{div}\ m)\} \cap [i - \mathrm{col}_i]$ captures the rows before $\mathrm{row}_i$. This structure is similar to Combiner-Fixed, except for the way that the abstraction (and thus the local expectation) is computed: Combiner-Fixed computes the abstraction based only on $r$ of partition $\Omega_i^r$, whereas $\omega_{\mathrm{axial\text{-}rowmajor}}$ depends on both $r$ and the column $\mathrm{col}_i$.
[0085] In all cases above, the cost is similar to that of the Axial Transformer, which is $O(L\sqrt{L})$ if we reshape the sequence into a 2D matrix with $n, m = O(\sqrt{L})$.
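The axial-vertical plan can be checked directly against its definition. The sketch keeps the document's 1-indexed positions and the row width $m$; the helper name `axial_vertical_lm` is hypothetical. It builds $\Omega_i^0 = \Omega_i^{\mathrm{sparseLM}}$ and the column groups $\Omega_i^r = \{j : j \equiv r \ (\mathrm{mod}\ m)\} \cap [i - \mathrm{col}_i]$, drops empty groups (which occur in the first row), and confirms that the parts are disjoint, cover $[i]$, and number at most $O(n + m)$ per position.

```python
from typing import FrozenSet, List, Set, Tuple


def axial_vertical_lm(i: int, m: int) -> Tuple[Set[int], List[FrozenSet[int]]]:
    """omega_axial-vertical^LM for 1-indexed position i in a sequence reshaped into rows of width m."""
    col_i = (i - 1) % m + 1
    # Omega_i^0 = Omega_i^sparseLM: same column or same row, restricted to [i].
    direct = {j for j in range(1, i + 1)
              if (j - 1) % m == (i - 1) % m or (j - 1) // m == (i - 1) // m}
    # Omega_i^r: column r within the rows before row_i, i.e., positions 1..(i - col_i).
    prefix = i - col_i
    groups = [frozenset(j for j in range(1, prefix + 1) if (j - r) % m == 0)
              for r in range(1, m + 1) if r != col_i]
    return direct, [g for g in groups if g]  # empty supports are dropped


n, m = 4, 5
for i in range(1, n * m + 1):
    direct, groups = axial_vertical_lm(i, m)
    parts = [direct, *groups]
    exact_cover = sum(len(p) for p in parts) == i and set().union(*parts) == set(range(1, i + 1))
    print(i, len(direct) + len(groups), exact_cover)
```

The printed term count per position grows like $n + m$ rather than $i$, which is where the $O(L\sqrt{L})$ cost comes from when $n, m = O(\sqrt{L})$.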
[0086] Combiner-Learnable
[0087] Another example implementation can also learn the factorization plan $\omega$ from the data. We illustrate this with the Routing Transformer and provide a way to enable full attention in the Routing Transformer following the Combiner principle.
[0088] For a specific layer, suppose we have learned disjoint regions (or clusters, in the Routing Transformer) $\{\Omega^r\}_{r=1}^{n}$, where $\bigcup_r \Omega^r = [L]$. In the Routing Transformer, we simply have $\Omega_i^{\mathrm{sparseMLM}} = \Omega^{r_i}$, where $r_i$ denotes the index of the region containing position $i$.
$\omega_{\mathrm{routing}}^{\mathrm{MLM}}: \Omega_i^0 = \Omega_i^{r_i}, \ldots$
[0089] Note that $n_i = n$ (the number of learned clusters) for all locations. The above factorization can only work for MLM. LM requires the following definition:
$\omega_{\mathrm{routing}}^{\mathrm{LM}}: \Omega_i^0 = \Omega_i^{r_i}, \ldots$
[0090] In general, both LM and MLM can have sub-quadratic cost when $n = O(\sqrt{L})$. However, routing variants (including the Routing Transformer) require a gather operation, which can be slow on TPUs. The Routing Transformer is described in Aurko Roy, Mohammad Saffar, Ashish Vaswani, and David Grangier. Efficient content-based sparse attention with routing transformers. Transactions of the Association for Computational Linguistics, 9:53-68, 2021.
Example Devices and Systems
[0092] The user computing device 102 can be any type of computing device, such as, for example, a personal computing device (e.g., laptop or desktop), a mobile computing device (e.g., smartphone or tablet), a gaming console or controller, a wearable computing device, an embedded computing device, or any other type of computing device.
[0093] The user computing device 102 includes one or more processors 112 and a memory 114. The one or more processors 112 can be any suitable processing device (e.g., a processor core, a microprocessor, an ASIC, an FPGA, a controller, a microcontroller, etc.) and can be one processor or a plurality of processors that are operatively connected. The memory 114 can include one or more non-transitory computer-readable storage media, such as RAM, ROM, EEPROM, EPROM, flash memory devices, magnetic disks, etc., and combinations thereof. The memory 114 can store data 116 and instructions 118 which are executed by the processor 112 to cause the user computing device 102 to perform operations.
[0094] In some implementations, the user computing device 102 can store or include one or more machine-learned models 120. For example, the machine-learned models 120 can be or can otherwise include various machine-learned models such as neural networks (e.g., deep neural networks) or other types of machine-learned models, including non-linear models and/or linear models. Neural networks can include feed-forward neural networks, recurrent neural networks (e.g., long short-term memory recurrent neural networks), convolutional neural networks or other forms of neural networks. Some example machine-learned models can leverage an attention mechanism such as self-attention. For example, some example machine-learned models can include multi-headed self-attention models (e.g., transformer models). Example machine-learned models 120 are discussed with reference to
[0095] In some implementations, the one or more machine-learned models 120 can be received from the server computing system 130 over network 180, stored in the user computing device memory 114, and then used or otherwise implemented by the one or more processors 112. In some implementations, the user computing device 102 can implement multiple parallel instances of a single machine-learned model 120 (e.g., to perform parallel processing across multiple instances of inputs).
[0096] Additionally or alternatively, one or more machine-learned models 140 can be included in or otherwise stored and implemented by the server computing system 130 that communicates with the user computing device 102 according to a client-server relationship. For example, the machine-learned models 140 can be implemented by the server computing system 130 as a portion of a web service. Thus, one or more models 120 can be stored and implemented at the user computing device 102 and/or one or more models 140 can be stored and implemented at the server computing system 130.
[0097] The user computing device 102 can also include one or more user input components 122 that receive user input. For example, the user input component 122 can be a touch-sensitive component (e.g., a touch-sensitive display screen or a touch pad) that is sensitive to the touch of a user input object (e.g., a finger or a stylus). The touch-sensitive component can serve to implement a virtual keyboard. Other example user input components include a microphone, a traditional keyboard, or other means by which a user can provide user input.
[0098] The server computing system 130 includes one or more processors 132 and a memory 134. The one or more processors 132 can be any suitable processing device (e.g., a processor core, a microprocessor, an ASIC, an FPGA, a controller, a microcontroller, etc.) and can be one processor or a plurality of processors that are operatively connected. The memory 134 can include one or more non-transitory computer-readable storage media, such as RAM, ROM, EEPROM, EPROM, flash memory devices, magnetic disks, etc., and combinations thereof. The memory 134 can store data 136 and instructions 138 which are executed by the processor 132 to cause the server computing system 130 to perform operations.
[0099] In some implementations, the server computing system 130 includes or is otherwise implemented by one or more server computing devices. In instances in which the server computing system 130 includes plural server computing devices, such server computing devices can operate according to sequential computing architectures, parallel computing architectures, or some combination thereof.
[0100] As described above, the server computing system 130 can store or otherwise include one or more machine-learned models 140. For example, the models 140 can be or can otherwise include various machine-learned models. Example machine-learned models include neural networks or other multi-layer non-linear models. Example neural networks include feed forward neural networks, deep neural networks, recurrent neural networks, and convolutional neural networks. Some example machine-learned models can leverage an attention mechanism such as self-attention. For example, some example machine-learned models can include multi-headed self-attention models (e.g., transformer models). Example models 140 are discussed with reference to
[0101] The user computing device 102 and/or the server computing system 130 can train the models 120 and/or 140 via interaction with the training computing system 150 that is communicatively coupled over the network 180. The training computing system 150 can be separate from the server computing system 130 or can be a portion of the server computing system 130.
[0102] The training computing system 150 includes one or more processors 152 and a memory 154. The one or more processors 152 can be any suitable processing device (e.g., a processor core, a microprocessor, an ASIC, an FPGA, a controller, a microcontroller, etc.) and can be one processor or a plurality of processors that are operatively connected. The memory 154 can include one or more non-transitory computer-readable storage media, such as RAM, ROM, EEPROM, EPROM, flash memory devices, magnetic disks, etc., and combinations thereof. The memory 154 can store data 156 and instructions 158 which are executed by the processor 152 to cause the training computing system 150 to perform operations. In some implementations, the training computing system 150 includes or is otherwise implemented by one or more server computing devices.
[0103] The training computing system 150 can include a model trainer 160 that trains the machine-learned models 120 and/or 140 stored at the user computing device 102 and/or the server computing system 130 using various training or learning techniques, such as, for example, backwards propagation of errors. For example, a loss function can be backpropagated through the model(s) to update one or more parameters of the model(s) (e.g., based on a gradient of the loss function). Various loss functions can be used such as mean squared error, likelihood loss, cross entropy loss, hinge loss, and/or various other loss functions. Gradient descent techniques can be used to iteratively update the parameters over a number of training iterations.
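As a generic illustration of the training loop described above (not the model trainer 160 itself), the sketch below fits a hypothetical logistic-regression model by gradient descent on a cross-entropy loss. The gradient is obtained with the chain rule, which is backpropagation for a one-layer model; the data, learning rate, and step count are arbitrary choices.

```python
import numpy as np


def train_logistic_regression(X, y, lr=0.5, steps=200):
    """Gradient descent on mean binary cross-entropy; returns weights, bias, loss history."""
    w = np.zeros(X.shape[1])
    b = 0.0
    losses = []
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))  # forward pass: predicted probabilities
        eps = 1e-12                              # guards against log(0)
        losses.append(-np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps)))
        grad_logits = (p - y) / len(y)           # backward pass: dLoss/dlogits
        w -= lr * (X.T @ grad_logits)            # chain rule into the weights
        b -= lr * grad_logits.sum()              # and into the bias
    return w, b, losses


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
y = (X[:, 0] + X[:, 1] > 0).astype(float)
w, b, losses = train_logistic_regression(X, y)
print(losses[0], losses[-1])  # starts at ln 2 and decreases
```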
[0104] In some implementations, performing backwards propagation of errors can include performing truncated backpropagation through time. The model trainer 160 can perform a number of generalization techniques (e.g., weight decay, dropout, etc.) to improve the generalization capability of the models being trained.
[0105] In particular, the model trainer 160 can train the machine-learned models 120 and/or 140 based on a set of training data 162. In some implementations, if the user has provided consent, the training examples can be provided by the user computing device 102. Thus, in such implementations, the model 120 provided to the user computing device 102 can be trained by the training computing system 150 on user-specific data received from the user computing device 102. In some instances, this process can be referred to as personalizing the model.
[0106] The model trainer 160 includes computer logic utilized to provide desired functionality. The model trainer 160 can be implemented in hardware, firmware, and/or software controlling a general purpose processor. For example, in some implementations, the model trainer 160 includes program files stored on a storage device, loaded into a memory and executed by one or more processors. In other implementations, the model trainer 160 includes one or more sets of computer-executable instructions that are stored in a tangible computer-readable storage medium such as RAM, hard disk, or optical or magnetic media.
[0107] The network 180 can be any type of communications network, such as a local area network (e.g., intranet), wide area network (e.g., Internet), or some combination thereof and can include any number of wired or wireless links. In general, communication over the network 180 can be carried via any type of wired and/or wireless connection, using a wide variety of communication protocols (e.g., TCP/IP, HTTP, SMTP, FTP), encodings or formats (e.g., HTML, XML), and/or protection schemes (e.g., VPN, secure HTTP, SSL).
[0108] The machine-learned models described in this specification may be used in a variety of tasks, applications, and/or use cases.
[0109] In some implementations, the input to the machine-learned model(s) of the present disclosure can be image data. The machine-learned model(s) can process the image data to generate an output. As an example, the machine-learned model(s) can process the image data to generate an image recognition output (e.g., a recognition of the image data, a latent embedding of the image data, an encoded representation of the image data, a hash of the image data, etc.). As another example, the machine-learned model(s) can process the image data to generate an image segmentation output. As another example, the machine-learned model(s) can process the image data to generate an image classification output. As another example, the machine-learned model(s) can process the image data to generate an image data modification output (e.g., an alteration of the image data, etc.). As another example, the machine-learned model(s) can process the image data to generate an encoded image data output (e.g., an encoded and/or compressed representation of the image data, etc.). As another example, the machine-learned model(s) can process the image data to generate an upscaled image data output. As another example, the machine-learned model(s) can process the image data to generate a prediction output.
[0110] In some implementations, the input to the machine-learned model(s) of the present disclosure can be text or natural language data. The machine-learned model(s) can process the text or natural language data to generate an output. As an example, the machine-learned model(s) can process the natural language data to generate a language encoding output. As another example, the machine-learned model(s) can process the text or natural language data to generate a latent text embedding output. As another example, the machine-learned model(s) can process the text or natural language data to generate a translation output. As another example, the machine-learned model(s) can process the text or natural language data to generate a classification output. As another example, the machine-learned model(s) can process the text or natural language data to generate a textual segmentation output. As another example, the machine-learned model(s) can process the text or natural language data to generate a semantic intent output. As another example, the machine-learned model(s) can process the text or natural language data to generate an upscaled text or natural language output (e.g., text or natural language data that is higher quality than the input text or natural language, etc.). As another example, the machine-learned model(s) can process the text or natural language data to generate a prediction output.
[0111] In some implementations, the input to the machine-learned model(s) of the present disclosure can be speech data. The machine-learned model(s) can process the speech data to generate an output. As an example, the machine-learned model(s) can process the speech data to generate a speech recognition output. As another example, the machine-learned model(s) can process the speech data to generate a speech translation output. As another example, the machine-learned model(s) can process the speech data to generate a latent embedding output. As another example, the machine-learned model(s) can process the speech data to generate an encoded speech output (e.g., an encoded and/or compressed representation of the speech data, etc.). As another example, the machine-learned model(s) can process the speech data to generate an upscaled speech output (e.g., speech data that is higher quality than the input speech data, etc.). As another example, the machine-learned model(s) can process the speech data to generate a textual representation output (e.g., a textual representation of the input speech data, etc.). As another example, the machine-learned model(s) can process the speech data to generate a prediction output.
[0112] In some implementations, the input to the machine-learned model(s) of the present disclosure can be latent encoding data (e.g., a latent space representation of an input, etc.). The machine-learned model(s) can process the latent encoding data to generate an output. As an example, the machine-learned model(s) can process the latent encoding data to generate a recognition output. As another example, the machine-learned model(s) can process the latent encoding data to generate a reconstruction output. As another example, the machine-learned model(s) can process the latent encoding data to generate a search output. As another example, the machine-learned model(s) can process the latent encoding data to generate a reclustering output. As another example, the machine-learned model(s) can process the latent encoding data to generate a prediction output.
[0113] In some implementations, the input to the machine-learned model(s) of the present disclosure can be statistical data. Statistical data can be, represent, or otherwise include data computed and/or calculated from some other data source. The machine-learned model(s) can process the statistical data to generate an output. As an example, the machine-learned model(s) can process the statistical data to generate a recognition output. As another example, the machine-learned model(s) can process the statistical data to generate a prediction output. As another example, the machine-learned model(s) can process the statistical data to generate a classification output. As another example, the machine-learned model(s) can process the statistical data to generate a segmentation output. As another example, the machine-learned model(s) can process the statistical data to generate a visualization output. As another example, the machine-learned model(s) can process the statistical data to generate a diagnostic output.
[0114] In some implementations, the input to the machine-learned model(s) of the present disclosure can be sensor data. The machine-learned model(s) can process the sensor data to generate an output. As an example, the machine-learned model(s) can process the sensor data to generate a recognition output. As another example, the machine-learned model(s) can process the sensor data to generate a prediction output. As another example, the machine-learned model(s) can process the sensor data to generate a classification output. As another example, the machine-learned model(s) can process the sensor data to generate a segmentation output. As another example, the machine-learned model(s) can process the sensor data to generate a visualization output. As another example, the machine-learned model(s) can process the sensor data to generate a diagnostic output. As another example, the machine-learned model(s) can process the sensor data to generate a detection output.
[0115] In some cases, the machine-learned model(s) can be configured to perform a task that includes encoding input data for reliable and/or efficient transmission or storage (and/or corresponding decoding). For example, the task may be an audio compression task. The input may include audio data and the output may comprise compressed audio data. In another example, the input includes visual data (e.g. one or more images or videos), the output comprises compressed visual data, and the task is a visual data compression task. In another example, the task may comprise generating an embedding for input data (e.g. input audio or visual data).
[0116] In some cases, the input includes visual data and the task is a computer vision task. In some cases, the input includes pixel data for one or more images and the task is an image processing task. For example, the image processing task can be image classification, where the output is a set of scores, each score corresponding to a different object class and representing the likelihood that the one or more images depict an object belonging to the object class. The image processing task may be object detection, where the image processing output identifies one or more regions in the one or more images and, for each region, a likelihood that the region depicts an object of interest. As another example, the image processing task can be image segmentation, where the image processing output defines, for each pixel in the one or more images, a respective likelihood for each category in a predetermined set of categories. For example, the set of categories can be foreground and background. As another example, the set of categories can be object classes. As another example, the image processing task can be depth estimation, where the image processing output defines, for each pixel in the one or more images, a respective depth value. As another example, the image processing task can be motion estimation, where the network input includes multiple images, and the image processing output defines, for each pixel of one of the input images, a motion of the scene depicted at the pixel between the images in the network input.
[0117] In some cases, the input includes audio data representing a spoken utterance and the task is a speech recognition task. The output may comprise a text output which is mapped to the spoken utterance. In some cases, the task comprises encrypting or decrypting input data. In some cases, the task comprises a microprocessor performance task, such as branch prediction or memory address translation.
[0120] The computing device 10 includes a number of applications (e.g., applications 1 through N). Each application contains its own machine learning library and machine-learned model(s). For example, each application can include a machine-learned model. Example applications include a text messaging application, an email application, a dictation application, a virtual keyboard application, a browser application, etc.
[0121] As illustrated in
[0123] The computing device 50 includes a number of applications (e.g., applications 1 through N). Each application is in communication with a central intelligence layer. Example applications include a text messaging application, an email application, a dictation application, a virtual keyboard application, a browser application, etc. In some implementations, each application can communicate with the central intelligence layer (and model(s) stored therein) using an API (e.g., a common API across all applications).
[0124] The central intelligence layer includes a number of machine-learned models. For example, as illustrated in
[0125] The central intelligence layer can communicate with a central device data layer. The central device data layer can be a centralized repository of data for the computing device 50. As illustrated in
Additional Disclosure
[0126] The technology discussed herein makes reference to servers, databases, software applications, and other computer-based systems, as well as actions taken and information sent to and from such systems. The inherent flexibility of computer-based systems allows for a great variety of possible configurations, combinations, and divisions of tasks and functionality between and among components. For instance, processes discussed herein can be implemented using a single device or component or multiple devices or components working in combination. Databases and applications can be implemented on a single system or distributed across multiple systems. Distributed components can operate sequentially or in parallel.
[0127] While the present subject matter has been described in detail with respect to various specific example embodiments thereof, each example is provided by way of explanation, not limitation of the disclosure. Those skilled in the art, upon attaining an understanding of the foregoing, can readily produce alterations to, variations of, and equivalents to such embodiments. Accordingly, the subject disclosure does not preclude inclusion of such modifications, variations and/or additions to the present subject matter as would be readily apparent to one of ordinary skill in the art. For instance, features illustrated or described as part of one embodiment can be used with another embodiment to yield a still further embodiment. Thus, it is intended that the present disclosure cover such alterations, variations, and equivalents.