Private Federated Learning with Reduced Communication Cost
20240281564 · 2024-08-22
Inventors
- Peter Kairouz (Seattle, WA, US)
- Christopher Choquette-Choo (Sunnyvale, CA, US)
- Sewoong Oh (Seattle, WA, US)
- Md Enayat Ullah (Baltimore, MD, US)
CPC Classification
International Classification
Abstract
New techniques are provided that reduce communication in private federated learning without the need to set or tune compression rates. Example on-the-fly methods automatically adjust the compression rate based on the error induced during training, while maintaining provable privacy guarantees through the use of secure aggregation and differential privacy.
Claims
1. A computer-implemented method for federated learning of a global version of a machine learning model with improved communication efficiency, the method comprising: for each of one or more update iterations: communicating, by a server computing system, a current set of adaptive compression parameters to a plurality of client computing devices; receiving, by the server computing system, a plurality of first compressed vectors respectively from the plurality of client computing devices, wherein the first compressed vector received from each client computing device has been generated by application by the client computing device of a compression algorithm to a respective update vector of the client computing device using the current set of adaptive compression parameters; receiving, by the server computing system, a plurality of second compressed vectors respectively from the plurality of client computing devices, wherein the second compressed vector received from each client computing device has been generated by application by the client computing device of the compression algorithm to the respective update vector of the client computing device using a second set of compression parameters; performing, by the server computing system, a global update to the global version of the machine learning model based at least in part on a first aggregate of the plurality of first compressed vectors; determining, by the server computing system, one or more compression error values based at least in part on a second aggregate of the plurality of second compressed vectors; and updating, by the server computing system, the current set of adaptive compression parameters based at least in part on the one or more compression error values.
2. The computer-implemented method of claim 1, wherein: the current set of adaptive compression parameters has a first compression rate; the second set of compression parameters has a second compression rate; and the first compression rate is smaller than the second compression rate.
3. The computer-implemented method of claim 1, wherein updating, by the server computing system, the current set of adaptive compression parameters based at least in part on the one or more compression error values comprises: when the one or more compression error values indicate compression error greater than a desired compression error: updating, by the server computing system, the current set of adaptive compression parameters so as to decrease a first compression rate associated with the current set of adaptive compression parameters; and when the one or more compression error values indicate compression error less than the desired compression error: updating, by the server computing system, the current set of adaptive compression parameters so as to increase the first compression rate associated with the current set of adaptive compression parameters.
4. The computer-implemented method of claim 1, wherein determining, by the server computing system, the one or more compression error values based at least in part on the second aggregate of the plurality of second compressed vectors comprises determining, by the server computing system as the one or more compression error values, a norm of the second aggregate of the plurality of second compressed vectors.
5. The computer-implemented method of claim 4, wherein determining, by the server computing system as the one or more compression error values, the norm of the second aggregate of the plurality of second compressed vectors comprises determining, by the server computing system as the one or more compression error values, a differentially private norm of the second aggregate of the plurality of second compressed vectors.
6. The computer-implemented method of claim 1, wherein performing, by the server computing system, the global update to the global version of the machine learning model based at least in part on the first aggregate of the plurality of first compressed vectors comprises: decompressing, by the server computing system, the first aggregate of the plurality of first compressed vectors to obtain an estimated mean; and performing, by the server computing system, the global update to the global version of the machine learning model based at least in part on the estimated mean.
7. The computer-implemented method of claim 6, wherein determining, by the server computing system, the one or more compression error values based at least in part on the second aggregate of the plurality of second compressed vectors comprises: re-compressing, by the server computing system, the estimated mean according to the second set of compression parameters to obtain a re-compressed aggregate; and determining, by the server computing system as the one or more compression error values, a difference between the re-compressed aggregate and the second aggregate.
8. The computer-implemented method of claim 1, wherein the first aggregate of the plurality of first compressed vectors comprises a differentially private aggregate of the plurality of first compressed vectors.
9. The computer-implemented method of claim 1, wherein the first aggregate of the plurality of first compressed vectors comprises a Secure Aggregation (SecAgg) aggregate of the plurality of first compressed vectors.
10. The computer-implemented method of claim 1, wherein the current set of adaptive compression parameters comprise one or more sketch sizes for one or more sketching operations, and wherein the plurality of first compressed vectors comprise sketched vectors generated by application of the one or more sketching operations.
11. The computer-implemented method of claim 1, wherein the second set of compression parameters is fixed for the update iterations.
12. The computer-implemented method of claim 1, wherein the respective update vector of each client computing device describes updates to model parameters of a local version of the machine learning model stored at the client computing device that result from training the local version of the machine learning model on local training data stored at the client computing device.
13. The computer-implemented method of claim 1, wherein the one or more update iterations comprise a plurality of update iterations.
14. A computer-implemented method for distributed mean estimation with improved communication efficiency, the method comprising: for each of one or more update iterations: communicating, by a server computing system, a current set of adaptive compression parameters to a plurality of client computing devices; receiving, by the server computing system, a plurality of first compressed vectors respectively from the plurality of client computing devices, wherein the first compressed vector received from each client computing device has been generated by application by the client computing device of a compression algorithm to a respective data vector of the client computing device using the current set of adaptive compression parameters; receiving, by the server computing system, a plurality of second compressed vectors respectively from the plurality of client computing devices, wherein the second compressed vector received from each client computing device has been generated by application by the client computing device of the compression algorithm to the respective data vector of the client computing device using a second set of compression parameters; determining, by the server computing system, an estimated mean of the data vectors of the client computing devices based at least in part on a first aggregate of the plurality of first compressed vectors; determining, by the server computing system, one or more compression error values based at least in part on a second aggregate of the plurality of second compressed vectors; and updating, by the server computing system, the current set of adaptive compression parameters based at least in part on the one or more compression error values.
15. A client computing device configured to perform operations, the operations comprising: for each of one or more update iterations: receiving, from a server computing system, a current set of adaptive compression parameters; training a local version of a machine learning model on local training data stored at the client computing device to generate an update vector that describes updates to model parameters of the local version of the machine learning model stored at the client computing device; applying a compression algorithm to the update vector using the current set of adaptive compression parameters to generate a first compressed vector; applying the compression algorithm to the update vector using a second set of compression parameters to generate a second compressed vector; and transmitting the first compressed vector and the second compressed vector to the server computing system.
16. The client computing device of claim 15, wherein, at each update iteration, the current set of adaptive compression parameters has been updated based on one or more compression error values generated based at least in part on the second compressed vector transmitted by the client computing device at the prior update iteration.
17. The client computing device of claim 15, wherein: the current set of adaptive compression parameters has a first compression rate; the second set of compression parameters has a second compression rate; and the first compression rate is smaller than the second compression rate.
Description
BRIEF DESCRIPTION OF THE DRAWINGS
[0028] Detailed discussion of embodiments directed to one of ordinary skill in the art is set forth in the specification, which makes reference to the appended figures, in which:
[0037] Reference numerals that are repeated across plural figures are intended to identify the same features in various implementations.
DETAILED DESCRIPTION
Overview
[0038] Example implementations of the present disclosure provide new techniques for reducing communication in federated learning (FL) without the need to set or tune compression rates. In particular, the present disclosure provides techniques that solve the problem of designing adaptive (e.g., self-tuning) compression methods in private FL. Example implementations described herein can be performed on-the-fly to automatically adjust the compression rate based on the error induced during training, while maintaining provable privacy guarantees through the use of secure aggregation and differential privacy. The techniques are provably instance-optimal for mean estimation, meaning that they can adapt to the hardness of the problem with minimal interactivity. One common principle behind the proposed techniques is to set the compression rate such that the compression error does not overwhelm the privacy error. These approaches enable high compression rates with an acceptable sacrifice in accuracy. The effectiveness of these approaches has been demonstrated on real-world datasets, achieving near-ideal compression rates without the need for tuning.
[0039] The systems and methods of the present disclosure provide a number of technical effects and benefits. As one example, the proposed techniques provide improved communication efficiency, while preserving privacy, for FL or other distributed mean estimation settings. By providing such improved communication efficiency, consumption of computational resources can be reduced. For example, consumption of memory, processor cycles, and/or network bandwidth can be reduced. These computational benefits can be achieved all while preserving user privacy.
[0040] Examples of embodiments and implementations of the systems and methods of the present disclosure are discussed in the following sections.
Example Implementations of Private Federated Learning with Autotuned Compression
Example Preliminaries
[0041] An algorithm A satisfies (ε, δ)-differential privacy if, for all datasets D and D′ differing in one data point and all events E in the range of the algorithm, P(A(D) ∈ E) ≤ e^ε · P(A(D′) ∈ E) + δ.
[0042] Secure Aggregation (SecAgg): SecAgg is a cryptographic technique that can be used by multiple parties to compute an aggregate value, such as a sum or average, without revealing individual contributions by each party to the computation.
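The cancellation idea behind SecAgg can be illustrated in a few lines. The following is a toy sketch with hypothetical names, not the disclosed protocol: each pair of clients shares a random mask that one adds and the other subtracts, so individual uploads look random but the masks cancel in the modular sum. A real SecAgg protocol derives the pairwise masks from key agreement and handles client dropouts.

```python
import random

def secagg_sum(client_vectors, modulus=2**32):
    """Toy secure-aggregation sketch: pairwise random masks are added by one
    client and subtracted by the other, so each masked upload is individually
    uninformative while the masks cancel exactly in the aggregate."""
    n = len(client_vectors)
    d = len(client_vectors[0])
    masked = [list(v) for v in client_vectors]
    for i in range(n):
        for j in range(i + 1, n):
            pair_mask = [random.randrange(modulus) for _ in range(d)]
            for q in range(d):
                masked[i][q] = (masked[i][q] + pair_mask[q]) % modulus
                masked[j][q] = (masked[j][q] - pair_mask[q]) % modulus
    # The server only ever sees the masked uploads; their sum reveals
    # nothing beyond the aggregate itself.
    return [sum(m[q] for m in masked) % modulus for q in range(d)]

total = secagg_sum([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(total)  # [12, 15, 18]
```
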
[0043] Count-mean sketching: In one embodiment, a compression technique based on the sparse Johnson-Lindenstrauss (JL) random matrix/count-mean sketch data structure can be used. The sketching operation can be a linear map, denoted as S: ℝ^d → ℝ^(PC), where P, C ∈ ℕ are parameters. The corresponding unsketching operation is denoted as U: ℝ^(PC) → ℝ^d. In some embodiments, a count-sketch data structure can be used. A count-sketch can be a linear map which, for j ∈ [P], is denoted as S_j: ℝ^d → ℝ^C. The count-sketch can be described using two hash functions, a bucketing hash h_j: [d] → [C] and a sign hash s_j: [d] → {−1, +1}, mapping the q-th input coordinate z_q to the bucket whose value is Σ_{i=1}^{d} s_j(i)·1{h_j(i) = h_j(q)}·z_i.
[0044] The count-mean sketch construction concatenates P count sketches to get S: ℝ^d → ℝ^(PC), mapping z as follows: S(z) = (1/√P)·(S_1(z), …, S_P(z)).
[0045] The above, being a JL matrix, approximately preserves norms of z, i.e., ‖S(z)‖ ≈ ‖z‖, which is useful in controlling sensitivity, thus enabling application of DP techniques. The unsketching operation is simply U(S(z)) = S^T S(z). This gives an unbiased estimate, E[S^T S(z)] = z, whose variance scales as d‖z‖²/(PC). This captures the trade-off between the compression rate, d/(PC), and error.
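The construction above can be sketched in code. The following is a minimal illustration (function names, the particular hash construction, and the parameters are ours, not from the disclosure): each of P count sketches scatters signed coordinates into C buckets, and unsketching applies the transpose, giving an unbiased but noisy estimate of z.

```python
import numpy as np

def make_sketch(d, P, C, seed=0):
    """Minimal count-mean sketch: P independent count sketches with C buckets
    each, stacked and scaled by 1/sqrt(P) so norms are approximately
    preserved. Returns the linear maps S and U, with U(y) = S^T y."""
    r = np.random.default_rng(seed)
    buckets = r.integers(0, C, size=(P, d))     # bucketing hashes h_j: [d] -> [C]
    signs = r.choice([-1.0, 1.0], size=(P, d))  # sign hashes s_j: [d] -> {-1, +1}

    def S(z):
        out = np.zeros((P, C))
        for j in range(P):
            # Scatter the signed coordinates of z into this sketch's buckets.
            np.add.at(out[j], buckets[j], signs[j] * z)
        return out / np.sqrt(P)

    def U(y):
        # Unsketch: S^T y, an unbiased estimate of z when y = S(z).
        rows = [signs[j] * y[j, buckets[j]] for j in range(P)]
        return np.sum(rows, axis=0) / np.sqrt(P)

    return S, U

S, U = make_sketch(d=8, P=4, C=4, seed=0)
z = np.arange(1.0, 9.0)
estimate = U(S(z))  # unbiased, with variance on the order of d*||z||^2/(P*C)
```

Averaging such estimates over independent hash seeds converges to z, reflecting the unbiasedness noted above.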
Example Instance-Optimal Federated Mean Estimation (FME)
[0046] A central operation in standard federated learning (FL) algorithms is averaging the client model updates in a distributed manner. This can be posed as a standard distributed mean estimation (DME) problem with n clients, each holding a vector z_i ∈ ℝ^d sampled i.i.d. from an unknown distribution D with population mean μ. The goal of the server is to estimate μ while communicating only a small number of bits with the clients at each round. Once a communication-efficient scheme is determined, it can be readily integrated into the learning algorithm of choice, such as FedAvg.
[0047] In some embodiments, in order to provide privacy and security of the clients' data, mean estimation for FL has additional requirements. For example, in some embodiments, clients can only be accessed via SecAgg and DP must be satisfied. This particular approach can be referred to as Federated Mean Estimation (FME). To bound the sensitivity of the empirical mean, μ̂({z_i}_{i=1}^n) := (1/n) Σ_{i=1}^n z_i, the data is assumed to be bounded as ‖z_i‖ ≤ G. Since gradient norm clipping is almost always used in the DP FL setting, G can be assumed to be known. The first result characterizing the communication cost of FME is an unbiased estimator based on count-mean sketching that satisfies (ε, δ)-differential privacy and achieves the (order) optimal error of

E‖μ̂ − μ‖² = O( G²/n + G²·d·log(1/δ)/(n²ε²) )   (Equation 1)

with an (order) optimal communication complexity of Θ(n²ε²). The error rate in Equation 1 is optimal as it matches the information-theoretic lower bound that holds even without any communication constraints. The communication cost of Θ(n²ε²) cannot be improved for the worst-case data, as it matches the lower bound.
Example Adapting to the Instance's Norm
[0048] Example embodiments of the present disclosure illustrate that, with a sketch of size O(PC) that costs O(PC) in per-client communication, an error bound can be achieved whose leading term scales with the norm of the mean, where the normalized norm of the mean is M := ‖μ‖/G ∈ [0, 1] and G is the maximum ‖z‖ as chosen for a sensitivity bound under DP FL. The first error term captures how the sketching error scales with the norm of the mean, MG. When M is significantly smaller than one, as motivated by the application to FL, a significantly smaller choice of the communication cost, PC = O(min{n²ε², (M² + 1/n)·n²ε²}), is sufficient to achieve the optimal error rate of Equation 1. The dominant term is smaller than the standard sketch size of PC = O(n²ε²) by a factor of M² ∈ [0, 1]. However, selecting this sketch size requires knowledge of M. This necessitates a scheme that adapts to the current instance by privately estimating M with a small communication cost. Example aspects of the present disclosure therefore include an interactive Adapt Norm FME algorithm.
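As a concrete (hypothetical) reading of this sketch-size rule, the choice could be expressed as follows. The function name and exact constants are ours, and a real deployment would also fold in the log(1/δ) factors from the analysis:

```python
def adaptive_sketch_size(n, eps, M, d=None):
    """Hypothetical sketch-size rule matching the analysis above: spend only
    as much communication as needed so that the compression error does not
    dominate the privacy error. M is the (estimated) normalized norm of the
    mean, n the number of clients, eps the DP parameter."""
    worst_case = n * n * eps * eps             # standard, norm-oblivious size
    adaptive = (M * M + 1.0 / n) * worst_case  # shrinks when the mean is small
    pc = min(worst_case, adaptive)
    if d is not None:
        pc = min(pc, d)                        # never exceed the raw dimension
    return max(1, int(pc))
```

For example, with n = 100 clients, ε = 1, and M = 0.1, the rule keeps only 200 of the worst-case 10,000 sketch coordinates, a 50× communication saving.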
[0049] The FME algorithm can be instance-optimal with respect to M if it achieves the optimal error rate of Equation 1 with communication complexity of O(M²n²ε²) for every instance whose norm of the mean is bounded by GM. Aspects of the present invention can include a novel adaptive compression scheme and instance optimality.
Example Instance-optimal FME for Norm M
[0050] In some embodiments, a two-round procedure that achieves the optimal error in Equation 1 with instance-optimal communication complexity of O(M²n²ε²) can be utilized, without prior knowledge of M. In one particular aspect, a count-mean sketch can be used in the first round to construct a private yet accurate estimate of M. This is enabled by the fact that the count-mean sketch approximately preserves norms. In the second round, the sketch size can be set based on the estimate of M. Such interactivity is (i) provably necessary for instance-optimal FME and (ii) needed in other problems that target instance-optimal bounds with DP.
[0051] In some embodiments, each client computes a sketch, clips it, and sends it to the server. Each client sends two sketches. The second sketch is used for estimating the statistic, to ensure that instance-optimal compression-utility tradeoffs can be achieved. Although only one sketch is needed for the Adapt Norm approach (the statistic can be estimated directly from the first sketch S, allowing a minor additional communication optimization), the two-sketch presentation keeps the result in line with the Adapt Tail approach, which requires both sketches.
TABLE-US-00001
Algorithm 1: Adapt Norm FME
Input: sketch sizes C_1, P, C̃, P̃; noise variances σ, σ̃; clipping threshold B
1: for j = 1 to 2 do
2:   Select n random clients and broadcast sketching operators S_j and S̃ of sizes (C_j, P) and (C̃, P̃), respectively.
3:   SecAgg: v_j = SecAgg({Q_j^(c)}_{c=1}^n) + N(0, σ²B²·I), ṽ_j = SecAgg({Q̃_j^(c)}_{c=1}^n), where Q_j^(c) ← clip_B(S_j(z_c^(j))), Q̃_j^(c) ← clip_B(S̃(z_c^(j)))
4:   Unsketch DP mean: compute μ̂_j = U(v_j)
5:   Private norm estimate: N̂_j = clip_B(‖ṽ_j‖) + Laplace(σ̃)
6:   Set the next round's sketch size C_{j+1} from N̂_j
7: end for
(Portions of the original listing are indicated as missing or illegible when filed.)
[0052] In some embodiments, the server protocol aggregates the sketches using SecAgg. In the first round, the Q_j's are not used, i.e., C_1 = 0. Only the Q̃_j's are used to construct a private estimate of the norm of the mean, N̂_1 (line 5). There is no need to unsketch the Q̃_j's, as the norm is preserved under the sketching operation. This estimate can be used to set the next round's sketch size, C_{j+1}, so as to ensure a compression error of the same order as the privacy and statistical error; this is the same choice made under oracle knowledge of M. In the second round and onwards, the first sketch can be used, which is aggregated using SecAgg and privatized by adding appropriately scaled Gaussian noise. Finally, the Q_j's are unsketched to estimate the mean. In some embodiments, the clients need not send the Q̃_j's after the first round.
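The two-round server flow above can be sketched as follows. This is an illustrative simplification with names of our choosing: plain vectors stand in for SecAgg-aggregated sketches so the interactive sketch-size selection is easy to follow, and the noise calibrations are the textbook Laplace and Gaussian mechanisms rather than the disclosure's exact parameters.

```python
import numpy as np

def adapt_norm_round_trip(client_vectors, eps=1.0, G=1.0, seed=0):
    """Simplified two-round Adapt Norm control flow: round 1 privately
    estimates the norm of the mean, round 2 uses that estimate to pick the
    sketch size and releases a DP mean."""
    r = np.random.default_rng(seed)
    n, d = client_vectors.shape
    mean = client_vectors.mean(axis=0)

    # Round 1: DP estimate of the norm of the mean. With each client clipped
    # to norm G, the clipped norm of the average has sensitivity G/n, so
    # Laplace noise of scale G/(n*eps) suffices.
    norm_est = min(np.linalg.norm(mean), G) + r.laplace(scale=G / (n * eps))
    M_hat = float(np.clip(norm_est / G, 0.0, 1.0))

    # Round 2: set the sketch size from the private norm estimate so the
    # compression error stays on the order of the privacy error, then
    # release a Gaussian-noised mean.
    sketch_size = max(1, int(min(1.0, M_hat**2 + 1.0 / n) * (n * eps) ** 2))
    noisy_mean = mean + r.normal(scale=G / (n * eps), size=d)
    return sketch_size, noisy_mean
```
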
Example Theoretical Analysis
[0053] The Adapt Norm approach can achieve instance-optimality with respect to M; the optimal error rate of Equation 1 can be achieved for every instance of the problem with an efficient communication complexity of O(M²n²ε²). The optimality follows from a matching lower bound. Interactivity is critical in achieving instance-optimality: this follows from a lower bound that proves a fundamental gap between interactive and non-interactive algorithms.
[0054] For any choice of the failure probability β, the algorithm satisfies (ε, δ)-DP, and the output achieves the error rate of Equation 1.
[0055] Finally, in some embodiments, the number of rounds is two and, with probability at least 1 − β, the total per-client communication complexity is Õ(min{n²ε², (M² + 1/(nε)² + 1/n)·n²ε²}).
[0056] The error rate matches Equation 1. Compared to the target communication complexity of O((M² + 1/n)·n²ε²), the above communication complexity has an additional term, 1/(nε)², which stems from the fact that the norm can only be accessed privately. In the interesting regime where the error is strictly less than the trivial M²G², the extra 1/(nε)² + 1/n is smaller than M². The resulting communication complexity is O(M²n²ε²/log(1/δ)). This nearly matches the oracle communication complexity attainable with knowledge of M.
[0057] In some embodiments, a key feature of the algorithm is interactivity: the norm of the mean estimated in the prior round is used to determine the sketch size in the next round. In some embodiments, at least two rounds of communication are necessary for any algorithm with instance-optimal communication complexity.
Example Adapting to the Instance's Tail-Norm
[0058] In some embodiments, a key aspect in norm-adaptive compression can be interactivity. On top of interactivity, another key aspect of tail-norm-adaptive compression can be count median-of-means sketching.
Count Median-of-means Sketching
[0059] In some embodiments, the sketching technique takes R ∈ ℕ independent count-mean sketches. Let S^(i): ℝ^d → ℝ^(PC) denote the i-th count-mean sketch. The sketching operation, S: ℝ^d → ℝ^(R·PC), is defined as the concatenation S(z) = (S^(1)(z), …, S^(R)(z)).
[0060] An example median-of-means unsketching takes the coordinate-wise median over the R unsketched estimates: U(S(z))_q = Median({((S^(i))^T S^(i)(z))_q}_{i=1}^R).
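The median-of-means unsketching step can be illustrated directly. The helper below (a hypothetical name of our choosing) takes the coordinate-wise median across R independent unsketched estimates; a single count-sketch estimate can be badly off in a few coordinates, and the median discards such outliers.

```python
import numpy as np

def median_of_means_unsketch(estimates):
    """Coordinate-wise median over R independent unsketched estimates, as in
    the count median-of-means construction described above."""
    return np.median(np.stack(estimates), axis=0)

# Three hypothetical unsketched estimates of the same 2-d vector; outlier
# coordinates (10.0 and -3.0 in the second slot) are discarded by the median.
estimates = [np.array([1.0, 10.0]), np.array([1.2, -3.0]), np.array([0.9, 2.0])]
robust = median_of_means_unsketch(estimates)  # coordinate-wise medians: 1.0 and 2.0
```
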
TABLE-US-00002
Algorithm 2: Adapt Tail FME
Input: sketch sizes R, C_1, P, R̃, C̃, P̃; noise σ, σ̃; sparsity levels {k_j}_j; clipping threshold B
1-3: Each round j, aggregate the clients' clipped sketches and the clipped second sketches via SecAgg
4: Unsketch DP mean: compute μ̂_j = Top_{k_j}(U(v_j))
5: Error estimate: N̂_j = ‖S̃(μ̂_j) − ṽ_j‖ + Laplace(2σ̃)
6: If N̂_j meets the stopping criterion, stop; otherwise set C_{j+1} by doubling the sketch size and continue
(Portions of the original listing are indicated as missing or illegible when filed.)
Approximately-sparse Setting
[0061] In some embodiments, improved guarantees can be achieved when μ is approximately sparse. This is captured by the tail-norm. Given z ∈ ℝ^d with ‖z‖ ≤ G and k ∈ [d], the normalized tail-norm is M_k(z) := ‖z − Top_k(z)‖/G, where Top_k(z) retains the k largest-magnitude coordinates of z. This measures the error in the best k-sparse approximation. The mean estimate via the count median-of-means sketch is still unbiased and, for C on the order of Pk, has error bounded in terms of the tail-norm.
[0062] If all tail-norms are known, then the sketch size can be set accordingly to achieve the error rate in Equation 1. When k = 0, this recovers the previous communication complexity of the count-mean sketch. For an optimal choice of k, this communication complexity can be smaller.
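The normalized tail-norm defined above can be computed directly. The small helper below uses names of our choosing; it zeroes the top-k coordinates by magnitude and measures what remains, relative to the clipping bound G:

```python
import numpy as np

def normalized_tail_norm(z, k, G):
    """Normalized tail-norm: distance of z to its best k-sparse (top-k by
    magnitude) approximation, scaled by the clipping bound G. With k = 0
    this is simply ||z||/G."""
    idx = np.argsort(np.abs(z))[::-1][:k]  # indices of the k largest coords
    residual = z.copy()
    residual[idx] = 0.0                    # remove the best k-sparse part
    return np.linalg.norm(residual) / G

z = np.array([3.0, 0.0, 4.0, 1.0])
tail = normalized_tail_norm(z, 2, G=np.linalg.norm(z))  # small: z is nearly 2-sparse
```
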
Instance-optimal FME for Tail-norm
[0063] In some embodiments, multiple tail norms are used. To obtain the multiple tail norms, a doubling-trick-based scheme can be used: starting from an optimistic guess of the sketch size, the guess is progressively doubled until an appropriate stopping criterion on the error is met. A challenge is in estimating the error of the first sketch S (for the current choice of sketch size), as naively this would require the uncompressed true vector z to be sent to the server. To this end, the second sketch S̃ can be used to obtain a reasonable estimate of this error, while using significantly less communication than transmitting the original z.
[0064] The unsketched estimate U(S(z)) of z can be re-sketched with S̃. The re-sketch can then be compared with the second sketch S̃(z) to get an estimate of the error: ‖S̃(U(S(z))) − S̃(z)‖² = ‖S̃(U(S(z)) − z)‖² ≈ ‖U(S(z)) − z‖², which uses the linearity and norm-preservation properties of the count sketch.
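This identity can be checked numerically. In the sketch below, a Gaussian JL matrix stands in for the second sketch S̃ (the argument only uses linearity and approximate norm preservation, which both maps share); the dimensions, the synthetic vectors, and the 0.1 perturbation standing in for U(S(z)) are all illustrative:

```python
import numpy as np

rng = np.random.default_rng(1)
d, m = 1000, 400
z = rng.normal(size=d)            # stand-in for the true client vector
u = z + 0.1 * rng.normal(size=d)  # stand-in for the unsketched estimate U(S(z))

# A norm-preserving linear map standing in for the second sketch S-tilde.
S2 = rng.normal(size=(m, d)) / np.sqrt(m)

true_err = np.linalg.norm(u - z)
est_err = np.linalg.norm(S2 @ u - S2 @ z)  # equals ||S2(u - z)|| by linearity
# est_err tracks true_err without the server ever seeing z uncompressed.
```
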
[0065] The client protocol can remain the same; each participating client sends two independent (count median-of-means) sketches to the server. The server, starting from an optimistic choice of initial sketch size P·C_1, obtains the aggregated sketches from the clients via SecAgg and adds noise to the first sketch for DP. It then unsketches the first sketch to get an estimate of the mean.
Example Federated Optimization/Learning
[0066] One example algorithm is as follows:
TABLE-US-00003
Algorithm 3: Adapt Norm FL
Input: sketch sizes L_1 = R·P·C_1 and L̃ = R̃·P̃·C̃; noise multiplier σ; clipping threshold B
1-3: Each round, aggregate the clients' clipped sketches via SecAgg
4: Unsketch DP mean and apply the server update
5: Privately estimate the norm of the mean from the second sketch
6: Set the next round's sketch size from the (stale) norm estimate
7: Add Gaussian noise N(0, σ²B²) for DP
(Portions of the original listing are indicated as missing or illegible when filed.)
[0067] In some embodiments, procedures for FME require multiple rounds of communication. While this can be achieved in FO by freezing the model for the interactive FME rounds, it is undesirable in a real-world application of FL as it would increase the total wall-clock time of the training process. Therefore, in some embodiments, an additional heuristic can be added to the Adapt Norm and Adapt Tail algorithms when used in FO/FL: use a stale estimate of the privately estimated statistic, computed from the second sketch in a prior round.
Two Stage Method (e.g., Algorithm 4)
[0068] First, a single fixed compression rate can be used. It can be assumed that the norm of the updates remains relatively stable throughout training. To estimate it, W warm-up rounds can be run as the first stage. Then, using this estimate, a fixed compression rate can be set by balancing the errors incurred due to compression and privacy, akin to Adapt Norm in FME; this rate is then used for the second stage. Because the first stage is run without compression, it is important that W is minimized, which may be possible through prior knowledge of the statistic, e.g., from proxy data distributions or other hyperparameter tuning runs.
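The two-stage schedule can be sketched as a small helper. This is a hypothetical reading with names and constants of our choosing: warm-up rounds run uncompressed and supply norm estimates, and the remaining rounds share one fixed sketch size derived from their average.

```python
def two_stage_compression_schedule(norm_estimates, n, eps, total_rounds, warmup):
    """Hypothetical two-stage schedule: `warmup` uncompressed rounds estimate
    the typical normalized update norm, then one fixed sketch size (balancing
    compression error against privacy error, as in Adapt Norm) is used for
    the remaining rounds. `None` marks an uncompressed round."""
    M = sum(norm_estimates[:warmup]) / warmup        # warm-up norm estimate
    fixed_pc = max(1, int(min(1.0, M * M + 1.0 / n) * (n * eps) ** 2))
    schedule = [None] * warmup                       # stage 1: no compression
    schedule += [fixed_pc] * (total_rounds - warmup) # stage 2: fixed rate
    return schedule
```
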
TABLE-US-00004
Algorithm 4: Two Stage FL
Input: sketch sizes L_1 = R·P·C_1 and L̃ = R̃·P̃·C̃; noise multiplier σ; model dimension d; adaptation method adapt; a constant c_0; clipping threshold B; rounds K; warm-up rounds W
1: for j = 1 to W do
2:   Collect the clients' clipped sketches clip_B(S_j^(i)(z_c^(j))) via SecAgg
3:   Estimate the norm of the aggregated updates
4:   Unsketch DP mean, adding Gaussian noise N(0, σ²B²/0.1)
5:   Apply the server update
6: end for
7: Set a fixed compression rate from the warm-up estimate for the remaining K − W rounds
(Portions of the original listing are indicated as missing or illegible when filed.)
Adapt Norm (e.g., Algorithm 3)
[0069] In some embodiments, an algorithm at every round uses two sketches: one to estimate the mean for FL and the other to compute an estimate of its norm, which is used to set the (first) sketch size for the next round. This is akin to the corresponding FME procedure, with the exception that it uses a stale estimate of the norm, from the prior round, to set the sketch size in the current round. Further, in some embodiments, the privacy budget between the mean and norm estimation parts can be split heuristically in the ratio 9:1, with sketch size parameters R̃ = 1 and R = P = P̃ = ⌈log d⌉. Finally, the constant c_0 is set such that the total error in FME, at every round, is at most 1.1 times the DP error.
[0070] In some embodiments, the Laplace noise added to norm can be replaced by Gaussian noise for ease of privacy accounting in practice.
Example Devices and Systems
[0072] The client computing devices 102A, 102N can be any type of computing device, such as, for example, a personal computing device (e.g., laptop or desktop), a mobile computing device (e.g., smartphone or tablet), a gaming console or controller, a wearable computing device, an embedded computing device, or any other type of computing device.
[0073] The client computing devices 102A, 102N include one or more processors 112A, 112N and a memory 114A, 114N. The one or more processors 112A, 112N can be any suitable processing device (e.g., a processor core, a microprocessor, an ASIC, an FPGA, a controller, a microcontroller, etc.) and can be one processor or a plurality of processors that are operatively connected. The memory 114A, 114N can include one or more non-transitory computer-readable storage media, such as RAM, ROM, EEPROM, EPROM, flash memory devices, magnetic disks, etc., and combinations thereof. The memory 114A, 114N can store data 116A, 116N and instructions 118A, 118N which are executed by the processors 112A, 112N to cause the client computing devices 102A, 102N to perform operations.
[0074] In some implementations, the client computing devices 102A, 102N can store or include one or more machine-learned models 120A, 120N. The one or more machine-learned models 120A, 120N can be local machine-learned models 121A, 121N that are stored locally on the client computing devices 102A, 102N and process data that is likewise stored locally on the client computing devices 102A, 102N. For example, the machine-learned models 120 can be or can otherwise include various machine-learned models such as neural networks (e.g., deep neural networks) or other types of machine-learned models, including non-linear models and/or linear models. Neural networks can include feed-forward neural networks, recurrent neural networks (e.g., long short-term memory recurrent neural networks), convolutional neural networks, or other forms of neural networks. Some example machine-learned models can leverage an attention mechanism such as self-attention. For example, some example machine-learned models can include multi-headed self-attention models (e.g., transformer models).
[0075] In some implementations, the one or more machine-learned models 120A, 120N can be received from the server computing system 130 over network 180, stored in the client computing device memory 114A, 114N, and then used or otherwise implemented by the one or more processors 112A, 112N. In some implementations, the client computing devices 102A, 102N can implement multiple parallel instances of a single machine-learned model 120A, 120N (e.g., to perform parallel classification across multiple instances of input).
[0076] Additionally, or alternatively, one or more machine-learned models 140 can be included in or otherwise stored and implemented by the server computing system 130 that communicates with the client computing device 102 according to a client-server relationship. In some instances, the one or more machine-learned models 140 can include a global model 145 having a plurality of parameters. For example, the machine-learned models 140 can be implemented by the server computing system 130 as a portion of a web service (e.g., an image classification service). Thus, one or more models 120 can be stored and implemented at the client computing device 102 and/or one or more models 140 can be stored and implemented at the server computing system 130.
[0077] The client computing devices 102A, 102N can also include one or more user input components 122A, 122N that receives user input. The user input component 122A can receive user input from a first user, and the user component 122N can receive user input from another user. For example, the user input component 122A, 122N can be a touch-sensitive component (e.g., a touch-sensitive display screen or a touch pad) that is sensitive to the touch of a user input object (e.g., a finger or a stylus). The touch-sensitive component can serve to implement a virtual keyboard. Other example user input components include a microphone, a traditional keyboard, or other means by which a user can provide user input.
[0078] The server computing system 130 includes one or more processors 132 and a memory 134. The one or more processors 132 can be any suitable processing device (e.g., a processor core, a microprocessor, an ASIC, an FPGA, a controller, a microcontroller, etc.) and can be one processor or a plurality of processors that are operatively connected. The memory 134 can include one or more non-transitory computer-readable storage media, such as RAM, ROM, EEPROM, EPROM, flash memory devices, magnetic disks, etc., and combinations thereof. The memory 134 can store data 136 and instructions 138 which are executed by the processor 132 to cause the server computing system 130 to perform operations.
[0079] In some implementations, the server computing system 130 includes or is otherwise implemented by one or more server computing devices. In instances in which the server computing system 130 includes plural server computing devices, such server computing devices can operate according to sequential computing architectures, parallel computing architectures, or some combination thereof.
[0080] As described above, the server computing system 130 can store or otherwise include one or more machine-learned models 140. For example, the models 140 can be or can otherwise include various machine-learned models. Example machine-learned models include neural networks or other multi-layer non-linear models. Example neural networks include feed forward neural networks, deep neural networks, recurrent neural networks, and convolutional neural networks. Some example machine-learned models can leverage an attention mechanism such as self-attention. For example, some example machine-learned models can include multi-headed self-attention models (e.g., transformer models).
[0081] In some instances, the computing devices/systems 102A, 102N, 130 can train the machine-learned models 120A, 120N and/or 140 stored at the client computing devices 102A, 102N and/or the server computing system 130 using various training or learning techniques, such as, for example, backwards propagation of errors. For example, a loss function can be backpropagated through the model(s) to determine updates to one or more parameters of the model(s) (e.g., based on a gradient of the loss function). Various loss functions can be used such as mean squared error, likelihood loss, cross entropy loss, hinge loss, and/or various other loss functions. Gradient descent techniques can be used to iteratively update the parameters over a number of training iterations.
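The gradient-descent update described above can be illustrated with a minimal sketch. The function name `local_training_step` and the callable `grad_fn` (which returns the loss gradient for one example) are hypothetical names introduced here for illustration, not part of the disclosure:

```python
import numpy as np

def local_training_step(params, grad_fn, batches, lr=0.1):
    """One pass of gradient descent over a client's local training examples.

    `grad_fn(params, x, y)` is a hypothetical callable that returns the
    gradient of the loss with respect to `params` for one example (x, y).
    """
    for x, y in batches:
        # Move parameters opposite the loss gradient, scaled by the learning rate.
        params = params - lr * grad_fn(params, x, y)
    return params
```

For example, with a squared-error loss L(p) = ||p − y||², the gradient is 2(p − y), and repeated steps drive the parameters toward the target.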
[0082] In some implementations, performing backwards propagation of errors can include performing truncated backpropagation through time. The computing devices/systems 102, 130 can perform a number of generalization techniques (e.g., weight decays, dropouts, etc.) to improve the generalization capability of the models being trained.
[0083] In particular, the client computing device 102A, 102N can include training data 162A, 162N such as a local training dataset including a plurality of training examples. The training examples can be used in the federated learning approach described herein to train the models 120A, 120N, 140.
[0084] The network 180 can be any type of communications network, such as a local area network (e.g., intranet), wide area network (e.g., Internet), or some combination thereof and can include any number of wired or wireless links. In general, communication over the network 180 can be carried via any type of wired and/or wireless connection, using a wide variety of communication protocols (e.g., TCP/IP, HTTP, SMTP, FTP), encodings or formats (e.g., HTML, XML), and/or protection schemes (e.g., VPN, secure HTTP, SSL).
[0085] The machine-learned models described in this specification may be used in a variety of tasks, applications, and/or use cases.
[0086] In some implementations, the input to the machine-learned model(s) of the present disclosure can be image data. The machine-learned model(s) can process the image data to generate an output. As an example, the machine-learned model(s) can process the image data to generate an image recognition output (e.g., a recognition of the image data, a latent embedding of the image data, an encoded representation of the image data, a hash of the image data, etc.). As another example, the machine-learned model(s) can process the image data to generate an image segmentation output. As another example, the machine-learned model(s) can process the image data to generate an image classification output. As another example, the machine-learned model(s) can process the image data to generate an image data modification output (e.g., an alteration of the image data, etc.). As another example, the machine-learned model(s) can process the image data to generate an encoded image data output (e.g., an encoded and/or compressed representation of the image data, etc.). As another example, the machine-learned model(s) can process the image data to generate an upscaled image data output. As another example, the machine-learned model(s) can process the image data to generate a prediction output.
[0087] In some implementations, the input to the machine-learned model(s) of the present disclosure can be text or natural language data. The machine-learned model(s) can process the text or natural language data to generate an output. As an example, the machine-learned model(s) can process the natural language data to generate a language encoding output. As another example, the machine-learned model(s) can process the text or natural language data to generate a latent text embedding output. As another example, the machine-learned model(s) can process the text or natural language data to generate a translation output. As another example, the machine-learned model(s) can process the text or natural language data to generate a classification output. As another example, the machine-learned model(s) can process the text or natural language data to generate a textual segmentation output. As another example, the machine-learned model(s) can process the text or natural language data to generate a semantic intent output. As another example, the machine-learned model(s) can process the text or natural language data to generate an upscaled text or natural language output (e.g., text or natural language data that is higher quality than the input text or natural language, etc.). As another example, the machine-learned model(s) can process the text or natural language data to generate a prediction output.
[0088] In some implementations, the input to the machine-learned model(s) of the present disclosure can be speech data. The machine-learned model(s) can process the speech data to generate an output. As an example, the machine-learned model(s) can process the speech data to generate a speech recognition output. As another example, the machine-learned model(s) can process the speech data to generate a speech translation output. As another example, the machine-learned model(s) can process the speech data to generate a latent embedding output. As another example, the machine-learned model(s) can process the speech data to generate an encoded speech output (e.g., an encoded and/or compressed representation of the speech data, etc.). As another example, the machine-learned model(s) can process the speech data to generate an upscaled speech output (e.g., speech data that is higher quality than the input speech data, etc.). As another example, the machine-learned model(s) can process the speech data to generate a textual representation output (e.g., a textual representation of the input speech data, etc.). As another example, the machine-learned model(s) can process the speech data to generate a prediction output.
[0089] In some implementations, the input to the machine-learned model(s) of the present disclosure can be latent encoding data (e.g., a latent space representation of an input, etc.). The machine-learned model(s) can process the latent encoding data to generate an output. As an example, the machine-learned model(s) can process the latent encoding data to generate a recognition output. As another example, the machine-learned model(s) can process the latent encoding data to generate a reconstruction output. As another example, the machine-learned model(s) can process the latent encoding data to generate a search output. As another example, the machine-learned model(s) can process the latent encoding data to generate a re-clustering output. As another example, the machine-learned model(s) can process the latent encoding data to generate a prediction output.
[0090] In some implementations, the input to the machine-learned model(s) of the present disclosure can be statistical data. Statistical data can be, represent, or otherwise include data computed and/or calculated from some other data source. The machine-learned model(s) can process the statistical data to generate an output. As an example, the machine-learned model(s) can process the statistical data to generate a recognition output. As another example, the machine-learned model(s) can process the statistical data to generate a prediction output. As another example, the machine-learned model(s) can process the statistical data to generate a classification output. As another example, the machine-learned model(s) can process the statistical data to generate a segmentation output. As another example, the machine-learned model(s) can process the statistical data to generate a visualization output. As another example, the machine-learned model(s) can process the statistical data to generate a diagnostic output.
[0091] In some implementations, the input to the machine-learned model(s) of the present disclosure can be sensor data. The machine-learned model(s) can process the sensor data to generate an output. As an example, the machine-learned model(s) can process the sensor data to generate a recognition output. As another example, the machine-learned model(s) can process the sensor data to generate a prediction output. As another example, the machine-learned model(s) can process the sensor data to generate a classification output. As another example, the machine-learned model(s) can process the sensor data to generate a segmentation output. As another example, the machine-learned model(s) can process the sensor data to generate a visualization output. As another example, the machine-learned model(s) can process the sensor data to generate a diagnostic output. As another example, the machine-learned model(s) can process the sensor data to generate a detection output.
[0092] In some cases, the machine-learned model(s) can be configured to perform a task that includes encoding input data for reliable and/or efficient transmission or storage (and/or corresponding decoding). For example, the task may be an audio compression task. The input may include audio data and the output may comprise compressed audio data. In another example, the input includes visual data (e.g., one or more images or videos), the output comprises compressed visual data, and the task is a visual data compression task. In another example, the task may comprise generating an embedding for input data (e.g., input audio or visual data).
[0093] In some cases, the input includes visual data, and the task is a computer vision task. In some cases, the input includes pixel data for one or more images and the task is an image processing task. For example, the image processing task can be image classification, where the output is a set of scores, each score corresponding to a different object class and representing the likelihood that the one or more images depict an object belonging to the object class. The image processing task may be object detection, where the image processing output identifies one or more regions in the one or more images and, for each region, a likelihood that region depicts an object of interest. As another example, the image processing task can be image segmentation, where the image processing output defines, for each pixel in the one or more images, a respective likelihood for each category in a predetermined set of categories. For example, the set of categories can be foreground and background. As another example, the set of categories can be object classes. As another example, the image processing task can be depth estimation, where the image processing output defines, for each pixel in the one or more images, a respective depth value. As another example, the image processing task can be motion estimation, where the network input includes multiple images, and the image processing output defines, for each pixel of one of the input images, a motion of the scene depicted at the pixel between the images in the network input.
[0094] In some cases, the input includes audio data representing a spoken utterance and the task is a speech recognition task. The output may comprise a text output which is mapped to the spoken utterance. In some cases, the task comprises encrypting or decrypting input data. In some cases, the task comprises a microprocessor performance task, such as branch prediction or memory address translation.
[0095]
[0096]
[0097] The computing device 10 includes a number of applications (e.g., applications 1 through N). Each application contains its own machine learning library and machine-learned model(s). For example, each application can include a machine-learned model. Example applications include a text messaging application, an email application, a dictation application, a virtual keyboard application, a browser application, etc.
[0098] As illustrated in
[0099]
[0100] The computing device 50 includes a number of applications (e.g., applications 1 through N). Each application is in communication with a central intelligence layer. Example applications include a text messaging application, an email application, a dictation application, a virtual keyboard application, a browser application, etc. In some implementations, each application can communicate with the central intelligence layer (and model(s) stored therein) using an API (e.g., a common API across all applications).
[0101] The central intelligence layer includes a number of machine-learned models. For example, as illustrated in
[0102] The central intelligence layer can communicate with a central device data layer. The central device data layer can be a centralized repository of data for the computing device 50. As illustrated in
[0103]
[0104] Client devices 202 can each be configured to determine update vectors descriptive of updates to one or more trainable parameters associated with model 206 by training the model 206 on training data 208. For instance, training data 208 can be data that is respectively stored locally on the client devices 202. The training data 208 can include audio files, image files, video files, a typing history, location history, and/or various other suitable data. In some implementations, the training data can be any data derived through a user interaction with a client device 202.
[0105] The client devices 202 can compress the update vectors using the compression parameter(s) to generate compressed vectors. The client devices 202 can transmit the compressed vectors to the server 204.
[0106] Once the server 204 has received the compressed vectors descriptive of updates to one or more trainable parameters from the client devices 202, the server 204 can aggregate (e.g., federated averaging, Secure Aggregation, etc.) the updates. For example, the server 204 can aggregate the compressed vectors and then decompress the aggregate (or vice versa, depending on the compression algorithm used). Subsequently, the server 204 can modify one or more parameters of the model 206 based on the decompressed aggregate of the compressed update vectors.
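When the compression algorithm is linear, aggregating the compressed vectors and then decompressing the aggregate yields the same result as decompressing first and aggregating afterward, which is why the order can be interchanged as noted above. A minimal sketch, assuming a Gaussian sketching operator with least-squares decompression (an illustrative choice; the disclosure leaves the compression algorithm open, and all function names here are hypothetical):

```python
import numpy as np

def make_sketch_operator(d, k, seed=0):
    """Random linear compression operator S in R^{k x d} (Gaussian sketch assumed)."""
    rng = np.random.default_rng(seed)
    return rng.standard_normal((k, d)) / np.sqrt(k)

def compress(S, update):
    """Client-side compression: project the update vector through S."""
    return S @ update

def decompress(S, compressed):
    """Least-squares decompression via the pseudo-inverse of S."""
    return np.linalg.pinv(S) @ compressed

def aggregate_and_decompress(S, compressed_vectors):
    """Server path: aggregate the compressed vectors, then decompress the
    aggregate. Because S is linear, this equals decompressing each vector
    first and averaging afterward."""
    return decompress(S, np.mean(compressed_vectors, axis=0))
```

Linearity is also what makes such compression compatible with Secure Aggregation, since the server only ever needs the sum of the clients' compressed vectors.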
[0107] Further to the descriptions above, a user may be provided with controls allowing the user to make an election as to both if and when systems, programs or features described herein may enable collection, storage, and/or use of user information (e.g., training data 208), and if the user is sent content or communications from a server. In addition, certain data may be treated in one or more ways before it is stored or used, so that personally identifiable information is removed. For example, a user's identity may be treated so that no personally identifiable information can be determined for the user, or a user's geographic location may be generalized where location information is obtained (such as to a city, ZIP code, or state level), so that a particular location of a user cannot be determined. Thus, the user may have control over what information is collected about the user, how that information is used, and what information is provided to the user.
[0108] Although training data 208 is illustrated in
[0109] Client devices 202 can be configured to provide the local updates (e.g., update vectors in compressed form) to server 204. As indicated above, training data 208 may be privacy sensitive. According to techniques described herein, the local updates can be performed and provided to server 204 without compromising the privacy of training data 208. For instance, in such implementations, training data 208 is not provided to server 204. The local updates do not include training data 208. In some implementations, one or more encryption techniques, random noise techniques (e.g., differential privacy techniques), and/or other security techniques can be applied to the training process to obscure any inferable information.
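One common random-noise technique in this setting is the Gaussian mechanism: each local update is clipped to a fixed L2 norm and Gaussian noise is added before transmission. The sketch below is illustrative only; the function name, clip norm, and noise multiplier are assumptions introduced here, not parameter values from the disclosure:

```python
import numpy as np

def privatize_update(update, clip_norm=1.0, noise_multiplier=1.0, rng=None):
    """Clip a local update to L2 norm `clip_norm` and add Gaussian noise,
    in the spirit of the differential-privacy techniques mentioned above."""
    rng = rng if rng is not None else np.random.default_rng()
    norm = np.linalg.norm(update)
    # Scale down only when the update exceeds the clipping norm.
    clipped = update * min(1.0, clip_norm / max(norm, 1e-12))
    # Noise standard deviation scales with the clipping norm (the sensitivity).
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=update.shape)
    return clipped + noise
```

Clipping bounds each client's influence on the aggregate, which is what makes the added noise yield a formal differential-privacy guarantee.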
[0110] As indicated above, server 204 can receive each local update (e.g., update vectors in compressed form) from client device 202, and can aggregate the local updates to determine a global update to the model 206. In some implementations, server 204 can determine an average (e.g., a weighted average) of the local updates and determine the global update based at least in part on the average.
[0111] In some implementations, updated parameters are provided to the server 204 by a plurality of client devices 202, and the respective updated parameters are summed across the plurality of client devices 202. The sum for each of the updated parameters may then be divided by a corresponding sum of weights for each parameter as provided by the clients to form a set of weighted average updated parameters. In some implementations, updated parameters are provided to the server 204 by a plurality of client devices 202, and the respective updated parameters scaled by their respective weights are summed across the plurality of clients to provide a set of weighted average updated parameters. In some examples, the weights may be correlated to a number of local training iterations or epochs so that more extensively trained updates contribute in a greater amount to the updated parameter version. In some examples, the weights may include a bitmask encoding observed entities in each training round (e.g., a bitmask may correspond to the indices of embeddings and/or negative samples provided to a client).
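The two weighted-averaging formulations described above are equivalent: scaling each update by its weight, summing, and dividing by the total weight. A minimal sketch (the function name is introduced here for illustration):

```python
import numpy as np

def weighted_average(updates, weights):
    """Weighted average of client updates: sum(w_i * u_i) / sum(w_i)."""
    updates = np.asarray(updates, dtype=float)
    weights = np.asarray(weights, dtype=float)
    # Broadcast each weight across its update vector, sum, then normalize.
    return (weights[:, None] * updates).sum(axis=0) / weights.sum()
```

For instance, weighting one client three times more heavily (e.g., because it performed more local training epochs) pulls the average toward that client's update.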
[0112] In some implementations, satisfactory convergence of the machine-learned models can be obtained without updating every parameter with each training iteration. In some examples, each training iteration includes computing updates for a target set of trainable parameters.
[0113] In some implementations, scaling or other techniques can be applied to the local updates to determine the global update. For instance, a local step size can be applied for each client device 202, the aggregation can be performed proportionally to various data partition sizes of client devices 202, and/or one or more scaling factors can be applied to the local and/or aggregated updates. It will be appreciated that various other techniques can be applied without deviating from the scope of the present disclosure.
[0114] The update vectors may include information indicative of the updated trainable parameters. The update vectors may include the locally updated trainable parameters (e.g., the updated parameters or a difference between the updated parameter and the previous parameter received from the server 204). In some examples, the update vectors may include an update term, a corresponding weight, and/or a corresponding learning rate, and the server may determine therewith an updated version of the corresponding trainable parameter. Communications between the server 204 and the client devices 202 can be encrypted or otherwise rendered private.
[0115] In general, the client devices 202 may compute local updates to trainable parameters periodically or continually. The server may also compute global updates based on the provided client updates periodically or continually. In some implementations, the learning of trainable parameters includes an online or continuous machine-learning algorithm. For instance, some implementations may continuously update trainable parameters within the global model.
Example Methods
[0116]
[0117]
[0118] At 404, the client can compute a first sketched vector Q_j using the current adaptive sketching operators S_j and a second sketched vector Q′_j using base sketching operators S′_j. At 406, the client can transmit the first sketched vector Q_j and the second sketched vector Q′_j. At 406, the server can receive the first sketched vector Q_j and the second sketched vector Q′_j. At 408, the server can determine a first aggregate v_j of the first sketched vectors Q_j.
[0119] At 410, the server can unsketch the first aggregate v_j to obtain a mean u. At 412, the server can determine a second aggregate v′_j of the second sketched vectors Q′_j. At 414, the server can estimate a norm N based on the second aggregate v′_j. At 416, the server can determine an update to the adaptive sketching operators S_j based on the norm N. In some implementations, after 416, the method can return to 402.
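Steps 404-416 above can be sketched end-to-end as follows. This is a simplified illustration only: the Gaussian sketching operators, the pseudo-inverse unsketching, the choice of error proxy for the norm N, and the specific rule for updating the sketch dimension are all assumptions introduced here, not details fixed by the disclosure:

```python
import numpy as np

def gaussian_operator(d, k, seed):
    """Illustrative Gaussian sketching operator in R^{k x d}."""
    rng = np.random.default_rng(seed)
    return rng.standard_normal((k, d)) / np.sqrt(k)

def method_400_round(client_updates, d, k, k_base, seed=0):
    """One round in the spirit of steps 404-416 of method 400."""
    S = gaussian_operator(d, k, seed)                 # adaptive operators S_j
    S_base = gaussian_operator(d, k_base, seed + 1)   # base operators S'_j
    # 404-406: each client sketches its update with both operators and transmits.
    Q = [S @ v for v in client_updates]
    Q_base = [S_base @ v for v in client_updates]
    # 408-410: aggregate the adaptive sketches and unsketch to obtain the mean u.
    v_agg = np.mean(Q, axis=0)
    u = np.linalg.pinv(S) @ v_agg
    # 412-414: aggregate the base sketches and estimate a norm N; here the
    # residual between the re-sketched mean and the base aggregate serves as
    # a proxy for the compression error.
    v_base = np.mean(Q_base, axis=0)
    N = np.linalg.norm(S_base @ u - v_base)
    # 416: grow the adaptive sketch dimension when the error proxy is large
    # (the tolerance 1e-6 is an arbitrary illustrative constant).
    k_next = min(2 * k, d) if N > 1e-6 else k
    return u, N, k_next
```

With k equal to the full dimension d, the sketch is invertible almost surely, so the recovered mean is exact and the error proxy is near zero.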
[0120]
[0121] At 504, the client can compute a first sketched vector Q_j using the current adaptive sketching operators S_j and a second sketched vector Q′_j using base sketching operators S′_j. At 506, the client can transmit the first sketched vector Q_j and the second sketched vector Q′_j. At 506, the server can receive the first sketched vector Q_j and the second sketched vector Q′_j. At 508, the server can determine a first aggregate v_j of the first sketched vectors Q_j. At 510, the server can determine a second aggregate v′_j of the second sketched vectors Q′_j. At 512, the server can unsketch the first aggregate v_j to determine a current mean estimate u_j.
[0122] Referring now to
[0123] If the noised version ẽ_j of the error e_j is less than a noised version γ̃_j of a current threshold value γ_j, then the method can proceed to 520 and set the current mean estimate u_j as a final mean estimate u. The method can then end.
[0124] However, if the noised version ẽ_j of the error e_j is greater than the noised version γ̃_j of the current threshold value γ_j, then the method can proceed to 522 and increase (e.g., double) the current adaptive sketching operators S_j. After 522, method 500 can return to 502.
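The doubling loop of method 500 can be sketched as follows. This is a simplified illustration under stated assumptions: Gaussian sketches, pseudo-inverse unsketching, a residual-based error proxy for e_j, and a `noise_scale` parameter standing in for the noise that produces ẽ_j and γ̃_j; none of these choices is fixed by the disclosure:

```python
import numpy as np

def adaptive_mean(client_updates, d, k0=2, k_base=4, threshold=1e-8,
                  noise_scale=0.0, seed=0):
    """Sketch-until-accurate loop in the spirit of method 500."""
    rng = np.random.default_rng(seed)
    # Fixed base sketching operators S'_j and their aggregate.
    S_base = rng.standard_normal((k_base, d)) / np.sqrt(k_base)
    v_base = np.mean([S_base @ v for v in client_updates], axis=0)
    k = k0
    while True:
        # 502-512: sketch with the current adaptive operators, aggregate,
        # and unsketch to obtain the current mean estimate u_j.
        S = rng.standard_normal((k, d)) / np.sqrt(k)
        v_agg = np.mean([S @ v for v in client_updates], axis=0)
        u = np.linalg.pinv(S) @ v_agg
        # 514-518: compare a noised error against a noised threshold.
        err = np.linalg.norm(S_base @ u - v_base) + rng.normal(0.0, noise_scale)
        thr = threshold + rng.normal(0.0, noise_scale)
        if err < thr or k >= d:
            return u, k          # 520: accept u_j as the final mean estimate u
        k = min(2 * k, d)        # 522: double the adaptive sketch size
```

Because the sketch size only grows when the measured error is too large, the compression rate adapts on the fly without being tuned in advance.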
[0125]
Additional Disclosure
[0126] The technology discussed herein makes reference to servers, databases, software applications, and other computer-based systems, as well as actions taken, and information sent to and from such systems. The inherent flexibility of computer-based systems allows for a great variety of possible configurations, combinations, and divisions of tasks and functionality between and among components. For instance, processes discussed herein can be implemented using a single device or component or multiple devices or components working in combination. Databases and applications can be implemented on a single system or distributed across multiple systems. Distributed components can operate sequentially or in parallel.
[0127] While the present subject matter has been described in detail with respect to various specific example embodiments thereof, each example is provided by way of explanation, not limitation of the disclosure. Those skilled in the art, upon attaining an understanding of the foregoing, can readily produce alterations to, variations of, and equivalents to such embodiments. Accordingly, the subject disclosure does not preclude inclusion of such modifications, variations and/or additions to the present subject matter as would be readily apparent to one of ordinary skill in the art. For instance, features illustrated or described as part of one embodiment can be used with another embodiment to yield a still further embodiment. Thus, it is intended that the present disclosure covers such alterations, variations, and equivalents.