METHOD AND APPARATUS FOR REAL-WORLD CROSS-MODAL RETRIEVAL PROBLEMS
20230154159 · 2023-05-18
Inventors
CPC classification
G06V10/44 (PHYSICS)
G06V10/762 (PHYSICS)
G06V10/774 (PHYSICS)
G06V20/70 (PHYSICS)
G06N3/008 (PHYSICS)
G06F18/256 (PHYSICS)
International classification
G06V10/774 (PHYSICS)
G06V10/44 (PHYSICS)
G06V10/762 (PHYSICS)
Abstract
Broadly speaking, the present application generally relates to a method for training a machine learning, ML, model to perform real world cross-modal retrieval problems, and to a computer-implemented method and apparatus for performing real world cross-modal retrieval problems, such as text-based video retrieval, sketch-based image retrieval, and image-text retrieval, using a trained machine learning, ML, model.
Claims
1. A computer-implemented method for training a machine learning, ML, model comprising a first feature extractor for extracting image features from an image and a second feature extractor, the method comprising: obtaining a dataset comprising a plurality of pairs of data instances, with each pair comprising a first data instance having a first modality and a second data instance having a second modality, wherein the first data instance is an image; evaluating, using the second feature extractor, at least some of the plurality of second data instances to extract a set of features for each of the evaluated second data instances; assigning a first set of class labels to the plurality of first data instances based on the extracted set of features for the second data instances; training the first feature extractor using the assigned first set of class labels; evaluating, using the first feature extractor, at least some of the plurality of first data instances to extract a set of image features for the evaluated first data instances; assigning a second set of class labels to the plurality of second data instances based on the extracted set of image features; and training the second feature extractor using the assigned second set of class labels.
2. The method of claim 1, wherein assigning the first set of class labels comprises: determining the clusters of second data instances based on the extracted set of features; and assigning each pair of data instances in a cluster the same class label.
3. The method of claim 1, wherein assigning the second set of class labels comprises: determining the clusters of first data instances based on the extracted set of image features; and assigning each pair of data instances in a cluster the same class label.
4. The method of claim 1, wherein training the first feature extractor and training the second feature extractor uses cross-entropy minimization.
5. The method of claim 1, wherein training the first and second feature extractors comprises defining a first linear classifier for the first modality, a second linear classifier for the second modality and a set of training parameters which are shared between the first and second modalities.
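6. The method of claim 5, wherein the first and second linear classifiers are denoted by p(y|x^A) and p(y|x^B) and are defined as,
$$p(y = j \mid x^M) = \frac{\exp\big(p_j^\top \Phi^M(x^M)/\tau\big)}{\sum_{l=1}^{K} \exp\big(p_l^\top \Phi^M(x^M)/\tau\big)}, \qquad M \in \{A, B\},$$
where P = {p_1, ..., p_K} are the training parameters shared between the first and second modalities, Φ^A and Φ^B are the first and second feature extractors, and τ is a temperature.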
7. The method of claim 5, further comprising: using optimization to obtain the first and second linear classifiers.
8. The method of claim 7, further comprising: estimating a first surrogate for the first linear classifier using the current second linear classifier; and estimating a second surrogate for the second linear classifier using the current first linear classifier.
9. The method of claim 8, wherein training the first feature extractor comprises updating the first feature extractor using the estimated first and second surrogates.
10. The method of claim 8, wherein training the second feature extractor comprises updating the second feature extractor using the estimated first and second surrogates.
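11. The method of claim 8, wherein the estimate for the first surrogate q(y|x^A) is found using an optimization defined as
$$\min_{Q^A \ge 0} \; \sum_{i=1}^{N}\sum_{j=1}^{K} Q^A_{ij}\,\big(-\log P^B_{ij}\big) \quad \text{s.t.} \quad \sum_{j=1}^{K} Q^A_{ij} = 1 \;\; \forall i, \qquad \sum_{i=1}^{N} Q^A_{ij} = \frac{N}{K} \;\; \forall j,$$
where Q^A_ij = q(y = j | x_i^A), P^B_ij = p(y = j | x_i^B), N is the number of pairs of data instances and K is the number of class labels.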
12. The method of claim 11, further comprising: solving the optimizations to estimate the first and second surrogates using the Sinkhorn-Knopp (SK) algorithm.
13. The method of claim 11, further comprising: selecting a batch comprising a plurality of pairs of data instances from the dataset; evaluating the first and second feature extractors for the selected batch; and storing the evaluations in at least one queue and performing the optimization on the queued evaluations.
14. The method of claim 1, further comprising: iterating each of the evaluating, assigning and training steps for separate batches of data.
15. An apparatus for training a machine learning, ML, model, the apparatus comprising: a memory storing one or more instructions; a machine learning model; and at least one processor configured to execute the one or more instructions stored in the memory to: obtain a dataset comprising a plurality of pairs of data instances, with each pair comprising a first data instance having a first modality and a second data instance having a second modality, wherein the first data instance is an image, evaluate, using the second feature extractor, at least some of the plurality of second data instances to extract a set of features for each of the evaluated second data instances, assign a first set of class labels to the plurality of first data instances based on the extracted set of features for the second data instances, train the first feature extractor using the assigned first set of class labels, evaluate, using the first feature extractor, at least some of the plurality of first data instances to extract a set of image features for the evaluated first data instances, assign a second set of class labels to the plurality of second data instances based on the extracted set of image features, and train the second feature extractor using the assigned second set of class labels.
16. The apparatus of claim 15, wherein, to assign the first set of class labels, the at least one processor is further configured to: determine the clusters of second data instances based on the extracted set of features, and assign each pair of data instances in a cluster the same class label.
17. The apparatus of claim 15, wherein, to assign the second set of class labels, the at least one processor is further configured to: determine the clusters of first data instances based on the extracted set of image features, and assign each pair of data instances in a cluster the same class label.
18. The apparatus of claim 15, wherein, to train the first and second feature extractors, the at least one processor is further configured to: define a first linear classifier for the first modality, a second linear classifier for the second modality and a set of training parameters which are shared between the first and second modalities.
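19. The apparatus of claim 18, wherein the first and second linear classifiers are denoted by p(y|x^A) and p(y|x^B) and are defined as,
$$p(y = j \mid x^M) = \frac{\exp\big(p_j^\top \Phi^M(x^M)/\tau\big)}{\sum_{l=1}^{K} \exp\big(p_l^\top \Phi^M(x^M)/\tau\big)}, \qquad M \in \{A, B\},$$
where P = {p_1, ..., p_K} are the training parameters shared between the first and second modalities, Φ^A and Φ^B are the first and second feature extractors, and τ is a temperature.
20. A non-transitory computer-readable recording medium having recorded thereon a program for executing the method of claim 1.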
Description
BRIEF DESCRIPTION OF THE DRAWINGS
[0050] The above and other features, aspects, and advantages of certain embodiments of the present disclosure will be more apparent from the following detailed description, taken in conjunction with the accompanying drawings in which like characters represent like parts throughout the drawings, and in which:
DETAILED DESCRIPTION
[0068] For the purpose of promoting an understanding of the principles of the disclosure, reference will now be made to various example embodiments illustrated in the drawings and specific language will be used to describe the same. It will nevertheless be understood that no limitation of the scope of the disclosure is thereby intended, such alterations and further modifications in the illustrated system, and such further applications of the principles of the disclosure as illustrated therein being contemplated as would normally occur to one skilled in the art to which the disclosure relates.
[0069] It will be understood by those skilled in the art that the foregoing general description and the following detailed description are merely illustrative of the present disclosure and are not intended to be restrictive thereof.
[0070] Reference throughout this disclosure to “an aspect”, “another aspect” or similar language may refer, for example, to a particular feature, structure, or characteristic described in connection with an embodiment being included in at least one embodiment of the present disclosure. Thus, appearances of the phrase “in an embodiment”, “in another embodiment” and similar language throughout this disclosure may, but do not necessarily, all refer to the same embodiment.
[0071] The terms “comprises”, “comprising”, or any other variations thereof, are intended to cover a non-exclusive inclusion, such that a process or method that comprises a list of steps does not include only those steps but may include other steps not expressly listed or inherent to such process or method. Similarly, one or more devices or sub-systems or elements or structures or components preceded by “comprises... a” does not, without more constraints, preclude the existence of other devices or other sub-systems or other elements or other structures or other components or additional devices or additional sub-systems or additional elements or additional structures or additional components.
[0072] Unless otherwise defined, all technical and scientific terms used herein have the same meaning as commonly understood by one of ordinary skill in the art to which this disclosure belongs. The system, methods, and examples provided herein are illustrative only and not intended to be limiting.
[0073] Broadly speaking, the present techniques generally relate to a method for training a machine learning, ML, model to perform real world cross-modal retrieval problems, and to a computer-implemented method and apparatus for performing real world cross-modal retrieval problems, such as text-based video retrieval, sketch-based image retrieval, and image-text retrieval, using a trained machine learning, ML, model. The training method has a reduced sampling complexity and avoids the potentially wrong assumption that instances from different pairs are automatically irrelevant. The method finds similar instances from other pairs, and the feature extractors are trained in such a way that same-class instances, even in different pairs, are well aligned.
[0074] As explained in more detail below, the method uses class prediction to learn the feature extractors, independently for each modality. The class predictions p can be expressed as
$$p\big(y \mid \Phi_I(I)\big), \qquad p\big(y \mid \Phi_T(T)\big)$$
where Φ_I is a first feature extractor for first data instances having a first modality (in this example, images I having an image modality), Φ_T is a second feature extractor for second data instances having a second modality (in this example, text captions T having a text modality) and y is a class label. In other words, each modality is handled independently and there are no pairwise terms, which reduces the sampling complexity to O(N), where N is the number of training pairs.
[0077] such that Φ_I(I) ≈ Φ_T(T) if and only if (I, T) is paired,
[0078] where Φ_I is the feature extractor for the images I and Φ_T is the feature extractor for the text T.
[0079] In the cross-modal retrieval problem, the training is typically supervised only by the relevant multi-modal pairs in the data. The contrastive learning approach is the most popular approach for this task. It aims to learn the cross-modal similarity measure by the intuitive criterion of pulling together relevant pairs and pushing away irrelevant ones. However, its sampling complexity for learning is quadratic, O(N²), in the number N of training data points. Moreover, it makes the potentially wrong assumption that instances in different pairs are automatically irrelevant.
[0084] The two steps, label prediction and supervised learning with swapped classes, are alternated to learn the optimal feature extraction networks, as illustrated in the drawings.
[0085] The clusters may be identified using any suitable technique. The most traditional and popular clustering technique is K-means clustering, which is described for example at https://en.wikipedia.org/wiki/K-means_clustering. K-means seeks the grouping (i.e. clustering) of the data points that minimises the sum of squared distances between each data point and the centroid (mean) of its cluster. Another known technique is to regard the (unknown) cluster labels as variables in an optimization problem and to solve for them. The objective function of such an optimization problem typically measures the similarity or cohesiveness of the data points that share the same cluster label; this approach is often called self-labelling. The method described in this application typically follows this latter approach, which is described in more detail in “Self-labelling via simultaneous clustering and representation learning” by Asano et al published in International Conference on Learning Representations in 2020.
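By way of a non-limiting illustration, the K-means baseline mentioned above can be sketched as Lloyd's algorithm in plain Python. This is only a reference sketch: the function name `kmeans`, the initialisation from randomly chosen data points and the fixed iteration count are assumptions, and the method described in this application instead follows the self-labelling approach.

```python
import random


def kmeans(points, k, n_iter=50, seed=0):
    """Lloyd's algorithm (hypothetical helper): alternately assign each point to its
    nearest centroid and move each centroid to the mean of its assigned points."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, k)]  # initialise from k distinct data points
    labels = [0] * len(points)
    for _ in range(n_iter):
        # Assignment step: nearest centroid in squared Euclidean distance.
        labels = [
            min(range(k), key=lambda c: sum((a - b) ** 2 for a, b in zip(p, centroids[c])))
            for p in points
        ]
        # Update step: each centroid becomes the mean of its cluster (unchanged if the cluster is empty).
        for c in range(k):
            members = [p for p, label in zip(points, labels) if label == c]
            if members:
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return labels, centroids
```

Each step can only decrease the within-cluster sum of squared distances, which is why the alternation converges to a (local) optimum of that objective.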
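[0086] Then, as shown in the drawings, the class labels obtained by clustering the images are assigned to the paired text captions, and the second feature extractor is trained by supervised classification with these swapped labels, i.e. by minimizing
$$-\log p\big(y \mid \Phi_T(T)\big)$$
where Φ_T is the second feature extractor for the second data instances T, p is the class prediction, and y is the assigned class label.
[0087] As shown in the drawings, the text captions are likewise clustered, their class labels are assigned to the paired images, and the first feature extractor is trained by minimizing
$$-\log p\big(y \mid \Phi_I(I)\big)$$
where Φ_I is the first feature extractor for the first data instances I, p is the class prediction, and y is the assigned class label. This approach is schematically illustrated in the drawings.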
[0089] The apparatus 200 is connected to a database 250 which may be local or remote from (i.e. at a different location from) the apparatus 200. The database 250 may store training data which is used when training the machine learning model 210. The database 250 may also store the output from the training process, including for example the class labels, the trained machine learning model and intermediate data which is generated in the process of training, e.g. the surrogates described below.
[0090] The training data comprise data instances from modality A and modality B, which may be represented by x^A and x^B, respectively. For instance, x^A is an image from the image modality, while x^B is a text/caption from the text modality. Throughout the description we use modality-wise feature representations, meaning that modality-wise feature extractors (neural networks) Φ^A(·) and Φ^B(·) are applied to x^A and x^B, respectively. Thus, as shown in the drawings, the similarity between x^A and x^B may be computed as the dot product s(x^A, x^B) = Φ^A(x^A)^⊤ Φ^B(x^B) of the two embeddings.
[0091] The goal is to learn the feature extractors so that relevant pairs x^A and x^B have a high similarity score s(x^A, x^B), while irrelevant pairs have a low similarity score. The main benefit of the modality-wise feature representation is computational efficiency: because each embedding is computed independently and compared by an inexpensive dot product, the approach scales to billions of instances at training and test time.
[0092] The training data are composed of relevant pairs
$$D = \{(x_i^A, x_i^B)\}_{i=1}^{N},$$
where x_i^A and x_i^B are the instances in the i-th relevant pair. At test time, a query is given from the query modality, say x^A, and the goal is to find the most relevant instance, say x^B, from the other modality, where the search is performed over the given test set of instances from that other modality.
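As a non-limiting sketch of the test-time search described above, assuming the embeddings Φ^A(x^A) and Φ^B(x^B) have already been computed and that the i-th query is paired with the i-th gallery item, retrieval reduces to ranking by dot-product similarity, and recall-at-K counts how often the paired item is ranked in the top K. The helper names `retrieve` and `recall_at_k` are hypothetical.

```python
def dot(u, v):
    """Similarity s(x^A, x^B) as the dot product of two precomputed embeddings."""
    return sum(a * b for a, b in zip(u, v))


def retrieve(query_emb, gallery_embs, top_k=1):
    """Indices of the top_k gallery embeddings (other modality) ranked by similarity to the query."""
    scores = [dot(query_emb, g) for g in gallery_embs]
    return sorted(range(len(gallery_embs)), key=lambda j: -scores[j])[:top_k]


def recall_at_k(query_embs, gallery_embs, k=1):
    """Fraction of queries i whose paired gallery item i appears among the top-k results."""
    hits = sum(i in retrieve(q, gallery_embs, k) for i, q in enumerate(query_embs))
    return hits / len(query_embs)
```

Because the gallery embeddings do not depend on the query, they can be computed once and reused for every query, which is the efficiency argument made in [0091].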
[0093] Our idea is to introduce (latent) semantic class labels for data instances and use them to learn the feature extractors. The class labels are intended to decide the relevance of data instances from different modalities, that is, x^A and x^B are considered relevant if their class labels are the same, and irrelevant otherwise. Obviously, the paired cross-modal instances in the training data must have the same class labels. But beyond this, instances from different pairs can also be deemed relevant if they belong to the same semantic class. The motivation is that, if the class labels are estimated accurately, learning the feature extractors can be turned into a supervised classification problem with linear sampling complexity.
[0094] More formally, we consider (unknown) class labels to be assigned to the data instances: y^A, y^B ∈ {1, ..., K} are the class labels for x^A and x^B, respectively, where K is chosen by the user. The relevance of x^A and x^B is determined by their class labels: x^A and x^B are deemed relevant if y^A = y^B and irrelevant if y^A ≠ y^B. If class labels bearing such semantics were known for the training data, training would become supervised learning that can be done separately for each modality. Such supervised learning would avoid pairwise terms in the loss function, leading to linear sampling complexity. However, the class labels are not known, and thus it is necessary to optimize them (i.e., using self-supervised learning) together with the feature extractors Φ^A(·) and Φ^B(·).
[0095] As shown in the drawings, a linear classifier is defined for each modality on top of the shared feature space:
$$p(y = j \mid x^M) = \frac{\exp\big(p_j^\top \Phi^M(x^M)/\tau\big)}{\sum_{l=1}^{K} \exp\big(p_l^\top \Phi^M(x^M)/\tau\big)}, \qquad M \in \{A, B\} \tag{2}$$
where P = {p_1, ..., p_K} are trainable parameters that are shared between the two modalities, Φ^M is the feature extractor for modality M, x^M is a data instance of modality M, j indexes the K classes, each p_j may be regarded as the prototype vector for class j that lies in the shared feature space, and τ is the temperature in the softmax. The softmax function is a well-known term of art; it is a smooth approximation to the arg max function and may also be termed the normalized exponential function. Using the softmax function may be beneficial compared with the arg max function because the objective remains differentiable, allowing gradient backpropagation to be used to update the model parameters. The softmax function converts values into probabilities, and in this case the value p(y = j | x^M) corresponds to the probability that the class label of x^M is j after evaluating Φ^M(x^M). The temperature τ highlights or attenuates the current model's decision score p_j^⊤ Φ^M(x^M) in the class prediction. Properly choosing τ is important: if τ is too large, the softmax output approaches the uniform distribution and the model's current class label prediction would effectively be unused (ignored) in training; if τ is too small, the output approaches a one-hot vector and we would rely too much on the model's current decision, which might be incorrect at an early training stage, potentially leading to an inaccurate retrieval model.
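The classifier (2) and the role of the temperature τ can be illustrated with the following minimal sketch; the function name `class_probs` and the toy prototype and feature values are assumptions for illustration only.

```python
import math


def class_probs(feature, prototypes, tau):
    """Classifier (2): softmax over the decision scores p_j . Phi^M(x^M), divided by tau."""
    scores = [sum(a * b for a, b in zip(p_j, feature)) for p_j in prototypes]
    logits = [s / tau for s in scores]
    m = max(logits)  # subtracting the max leaves the softmax unchanged but avoids overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


protos = [(1.0, 0.0), (0.0, 1.0), (-1.0, 0.0)]  # toy prototypes p_1, p_2, p_3 (hypothetical)
feature = (0.8, 0.2)                            # toy embedding Phi^M(x^M) (hypothetical)
sharp = class_probs(feature, protos, tau=0.05)  # small tau: nearly one-hot on the first class
flat = class_probs(feature, protos, tau=100.0)  # large tau: nearly uniform
```

With these toy values, τ = 0.05 puts more than 99% of the probability on the first class, while τ = 100 leaves all three probabilities within about 0.01 of 1/3, matching the trade-off described above.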
[0096] As shown in the drawings, if the true class distribution p_true(y|x^A) were known, the classifier for modality A could be learned by minimizing the cross-entropy loss
$$\min_{P,\,\Phi^A} \; \mathbb{E}_{x^A \sim D}\Big[-\sum_{j=1}^{K} p_{\text{true}}(y = j \mid x^A)\,\log p(y = j \mid x^A)\Big]$$
with respect to the trainable parameters P and the network parameters of Φ^A(·) (similarly for modality B). There is no access to p_true(y|x^A), and thus one may be tempted to use the model for the first linear classifier p(y|x^A) in (2) in its place. However, this can easily lead to a degenerate solution, such as one that always puts all the probability mass on a single class (thus attaining the optimal cross-entropy loss of 0). Moreover, this would make learning Φ^A(·) and Φ^B(·) nearly independent, with the two interacting only through the shared trainable parameters P.
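To make the degenerate solution concrete, the following hypothetical sketch evaluates the cross-entropy when each prediction is used as its own target. The loss then equals the entropy of the prediction, so a model that always predicts the same one-hot class attains the minimum value of 0 without learning anything useful. The helper names are assumptions.

```python
import math


def cross_entropy(target, pred, eps=1e-12):
    """-sum_j target_j * log pred_j for a single instance (eps guards against log(0))."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, pred))


def self_target_loss(preds):
    """Average cross-entropy when each prediction p(y|x^A) is also used as its own target."""
    return sum(cross_entropy(p, p) for p in preds) / len(preds)


collapsed = [[1.0, 0.0, 0.0]] * 4          # every instance assigned to class 1
spread = [[1 / 3, 1 / 3, 1 / 3]] * 4       # maximally uncertain predictions
```

Here `self_target_loss(collapsed)` is (numerically) 0, whereas `self_target_loss(spread)` equals log 3, so minimizing this loss actively favours the collapsed solution.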
[0097] Instead of using the model's own prediction as the target of the cross-entropy loss, the optimization module 206 may use first and second surrogates 232, 234 for the first and second linear classifiers. An optimization problem is formed to estimate a first surrogate 232, which is an estimate of the true conditional class distribution for modality A, p_true(y|x^A), and a second surrogate 234, which is an estimate of the true conditional class distribution for modality B, p_true(y|x^B). The first surrogate 232 is denoted by q(y|x^A) and the second surrogate 234 is denoted by q(y|x^B). The first surrogate 232 is estimated using the information from the other (second) modality B, while imposing additional constraints to avoid the degenerate solutions. Similarly, the second surrogate 234 is estimated using the information from the other (first) modality A, under the same constraints.
[0098] More specifically, we optimize the first and second surrogates according to the following two criteria. First, each surrogate for the class distribution of one modality needs to be well aligned with the current estimate of the linear classifier in the other modality for each paired data instance. In other words, q(y|x^A) needs to be well aligned with the current estimate of p(y|x^B) for the x^B that is paired with x^A, and similarly, q(y|x^B) needs to be well aligned with the current estimate of p(y|x^A) for the x^A that is paired with x^B. This follows from the aforementioned requirement on the class labels, namely that the class labels (more generally, their distributions) of paired instances should match. Secondly, the marginal class distribution
$$q(y = j) = \frac{1}{N}\sum_{i=1}^{N} q(y = j \mid x_i^A)$$
is constrained to be uniform, i.e. q(y = j) = 1/K for every class j. This constraint arises naturally from the symmetry of the class labels, is a reasonable assumption about the true class distribution, and rules out the degenerate solutions discussed above, since a solution that puts all the probability mass on a single class would violate it.
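[0099] To summarize, the following is the optimization problem for the surrogate q(y|x^A):
$$\min_{Q^A \ge 0} \; \sum_{i=1}^{N}\sum_{j=1}^{K} Q^A_{ij}\,\big(-\log P^B_{ij}\big) \quad \text{s.t.} \quad \sum_{j=1}^{K} Q^A_{ij} = 1 \;\; \forall i, \qquad \sum_{i=1}^{N} Q^A_{ij} = \frac{N}{K} \;\; \forall j \tag{3}$$
where Q^A is an (N × K) matrix with entries Q^A_ij = q(y = j | x_i^A), N is the number of data instances x_i^A in the dataset D, K is the number of class labels y, and P^B_ij = p(y = j | x_i^B) is the current estimate of the probability that the class label of x_i^B, which is paired with x_i^A, has the value j. The first constraint makes each row of Q^A a probability distribution, and the second imposes the uniform class marginal.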
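[0100] We perform the analogous optimization for the surrogate q(y|x^B) to approximate p_true(y|x^B) by exchanging the roles of A and B. In other words,
$$\min_{Q^B \ge 0} \; \sum_{i=1}^{N}\sum_{j=1}^{K} Q^B_{ij}\,\big(-\log P^A_{ij}\big) \quad \text{s.t.} \quad \sum_{j=1}^{K} Q^B_{ij} = 1 \;\; \forall i, \qquad \sum_{i=1}^{N} Q^B_{ij} = \frac{N}{K} \;\; \forall j \tag{4}$$
where Q^B is an (N × K) matrix with entries Q^B_ij = q(y = j | x_i^B), N is the number of data instances x_i^B in the dataset D, K is the number of class labels y, and P^A_ij = p(y = j | x_i^A) is the current estimate of the probability of the class label having the value j for the x_i^A that is paired with x_i^B.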
[0101] The optimal solutions (surrogates) are denoted by q^A and q^B, where the superscript distinguishes the two modalities. Note that during the optimization of (3) for q^A and of (4) for q^B, we fix the model parameters, that is, P and the feature extractor networks. Hence the overall optimization alternates between: i) the surrogate optimizations (3) and (4) with P, Φ^A, Φ^B fixed, and ii) the supervised (cross-entropy) loss minimization with q^A and q^B fixed. The supervised loss minimization with loss L_s can be written as follows (the subscript s stands for SwAMP):
$$\min_{P,\,\Phi^A,\,\Phi^B} L_s(P, \Phi^A, \Phi^B) = -\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{K}\Big[q^A(y = j \mid x_i^A)\log p(y = j \mid x_i^A) + q^B(y = j \mid x_i^B)\log p(y = j \mid x_i^B)\Big] \tag{5}$$
where P = {p_1, ..., p_K} are trainable parameters that are shared between the two modalities A and B, Φ^A is the feature extractor for the modality A, Φ^B is the feature extractor for the modality B, q^A(y|x^A) and q^B(y|x^B) are the surrogates for the first and second linear classifiers p(y|x^A) and p(y|x^B), y is the label, and x_i^A and x_i^B are the i-th elements of the two modalities in the dataset D.
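The SwAMP loss (5) can be evaluated as in the sketch below, assuming the surrogates q^A, q^B and the predictions p(y|x^A), p(y|x^B) are given as N × K lists of probabilities. The function name `swamp_loss` and the small `eps` guard against log(0) are assumptions; in training, q^A and q^B are held fixed while the loss is differentiated with respect to P, Φ^A and Φ^B.

```python
import math


def swamp_loss(q_a, p_a, q_b, p_b, eps=1e-12):
    """L_s in (5): -(1/N) sum_i sum_j [q^A_ij log p^A_ij + q^B_ij log p^B_ij].

    q_a, q_b: fixed surrogate targets (e.g. from the SK step), one row per pair.
    p_a, p_b: current classifier outputs p(y|x_i^A), p(y|x_i^B)."""
    n = len(q_a)
    total = 0.0
    for qa, pa, qb, pb in zip(q_a, p_a, q_b, p_b):
        total -= sum(q * math.log(p + eps) for q, p in zip(qa, pa))  # modality A term
        total -= sum(q * math.log(p + eps) for q, p in zip(qb, pb))  # modality B term
    return total / n
```

For example, with one-hot targets and predictions of [0.5, 0.5] for every instance, each modality contributes log 2 per pair, giving L_s = 2 log 2.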
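[0102] Optimizing (3) and (4) may be achieved by a variation of the optimal transport (OT) problem, which is described for example in “Optimal Transport: Old and New” by Villani published by Springer in 2008. For optimizing (3) and (4), the cost matrices may be expressed as
$$C^A = -\log P^B, \qquad C^B = -\log P^A,$$
where the logarithm is taken element-wise, P^A_ij = p(y = j | x_i^A) and P^B_ij = p(y = j | x_i^B).
[0103] The marginal constraints may be expressed as
$$r = \mathbf{1}_N, \qquad c = \frac{N}{K}\,\mathbf{1}_K,$$
where 1_N denotes the all-ones vector of length N, so that each solution must lie in the transport polytope
$$U(r, c) = \big\{Q \in \mathbb{R}_{+}^{N \times K} : Q\,\mathbf{1}_K = r,\; Q^\top \mathbf{1}_N = c\big\}.$$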
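[0104] Although OT is known to be an instance of linear programming (LP), conventional LP solvers are not suitable for large-scale problems. As is common practice, we relax the problem by augmenting the loss with an entropic regularizer (thus penalizing small entropy). The augmented losses for q(y|x^A) and q(y|x^B) may be expressed, respectively, as follows:
$$\min_{Q^A \in U(r, c)} \langle Q^A, C^A \rangle - \frac{1}{\eta} H(Q^A), \qquad \min_{Q^B \in U(r, c)} \langle Q^B, C^B \rangle - \frac{1}{\eta} H(Q^B),$$
where C^A and C^B are the cost matrices above, ⟨Q, C⟩ = Σ_ij Q_ij C_ij, H(Q) = −Σ_ij Q_ij log Q_ij is the entropy, U(r, c) = {Q ≥ 0 : Q 1_K = r, Q^⊤ 1_N = c} is the set of matrices satisfying the marginal constraints r = 1_N and c = (N/K) 1_K defined above, and η is the regularization trade-off hyperparameter.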
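[0105] The augmented loss may be solved efficiently by the Sinkhorn-Knopp (SK) algorithm, which is described for example in “Sinkhorn Distances: Lightspeed Computation of Optimal Transportation Distances” by Cuturi published in Advances in Neural Information Processing Systems in 2013. The SK algorithm finds the optimal solutions as
$$Q^A = \mathrm{diag}(u^A)\,(P^B)^{\eta}\,\mathrm{diag}(v^A), \qquad Q^B = \mathrm{diag}(u^B)\,(P^A)^{\eta}\,\mathrm{diag}(v^B),$$
where (P^B)^η and (P^A)^η denote element-wise powers, i.e. (P^B)^η = exp(−η C^A) and (P^A)^η = exp(−η C^B) for the cost matrices C^A = −log P^B and C^B = −log P^A, and the vectors u^A ∈ ℝ^N and v^A ∈ ℝ^K are the fixed points of
$$u^A = r \oslash \big((P^B)^{\eta}\, v^A\big), \qquad v^A = c \oslash \big(((P^B)^{\eta})^{\top} u^A\big),$$
and the vectors u^B ∈ ℝ^N and v^B ∈ ℝ^K are the fixed points of
$$u^B = r \oslash \big((P^A)^{\eta}\, v^B\big), \qquad v^B = c \oslash \big(((P^A)^{\eta})^{\top} u^B\big),$$
with ⊘ denoting element-wise division and r = 1_N, c = (N/K) 1_K the marginal constraints.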
[0106] The fixed-point iteration usually converges after a few iterations. We denote the algorithm as:
$$Q^A = \mathrm{SK}(P^B, \eta), \qquad Q^B = \mathrm{SK}(P^A, \eta).$$
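The fixed-point iteration above can be sketched in plain Python as follows. The sketch assumes r = 1_N and c = (N/K) 1_K and runs a fixed number of iterations instead of testing for convergence; the function name `sinkhorn_knopp` is hypothetical. Calling it with P^B returns Q^A, and calling it with P^A returns Q^B.

```python
def sinkhorn_knopp(probs, eta, n_iter=100):
    """Entropic-OT surrogate Q = diag(u) (P)^eta diag(v).

    probs: N x K class probabilities from the *other* modality (all entries > 0).
    Rows of the result sum to 1 (r = 1_N); columns sum to N/K (uniform class marginal)."""
    n, k = len(probs), len(probs[0])
    kernel = [[p ** eta for p in row] for row in probs]  # element-wise power = exp(-eta * C)
    r, c = 1.0, n / k                                    # row and column marginals
    v = [1.0] * k
    for _ in range(n_iter):
        # Alternate the two scaling updates u = r / (K v) and v = c / (K^T u).
        u = [r / sum(m * vj for m, vj in zip(row, v)) for row in kernel]
        v = [c / sum(kernel[i][j] * u[i] for i in range(n)) for j in range(k)]
    return [[u[i] * kernel[i][j] * v[j] for j in range(k)] for i in range(n)]
```

When every row of the input is the same collapsed prediction, for example [0.9, 0.05, 0.05] for six instances, the returned matrix has every entry equal to 1/3, which illustrates how the uniform class marginal prevents all instances from being assigned to one class.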
[0107] One challenge in optimizing (3) and (4) with the SK algorithm, however, is that the loss involves the entire dataset D, which means that the second part of the overall optimization, defined in (5), would have to be deferred until q^A and q^B are optimized for an entire data epoch. Simply replacing D with a minibatch might be dangerous, since a minibatch covers the population class marginal distribution poorly. We need a subset of D that is considerably larger than a minibatch to roughly meet the (uniform) class constraint. To this end, we adopt the (FIFO) queues 206 shown in the drawings.
[0108] To keep the queues 206 filled with the latest features, we insert the features of the current minibatch into the queues 206 and then perform the SK algorithm. Once (3) and (4) are solved, we can optimize (5) by gradient descent, using only the current minibatch portion of q. The final loss function may be a combination of the SwAMP loss and the contrastive loss:
$$\min_{P,\,\Phi^A,\,\Phi^B} L(P, \Phi^A, \Phi^B) = L_c(\Phi^A, \Phi^B) + \lambda\, L_s(P, \Phi^A, \Phi^B) \tag{7}$$
where λ is the trade-off hyperparameter, L_c is the contrastive loss, L_s is the SwAMP loss defined in (5), P = {p_1, ..., p_K} are trainable parameters that are shared between the two modalities A and B, Φ^A is the feature extractor for the modality A and Φ^B is the feature extractor for the modality B.
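A minimal sketch of the FIFO queue mechanism, assuming the queue stores one row per data instance and that SK is run on the full queue contents after the current minibatch has been appended. Because the newest rows are at the end, the minibatch portion of q is simply the last rows of the SK output. The class name `FeatureQueue` and the helper `minibatch_portion` are hypothetical; the queue size should be at least the minibatch size.

```python
from collections import deque


class FeatureQueue:
    """Hypothetical FIFO buffer holding the most recent per-instance rows (e.g. embeddings)
    so that the SK step sees many more rows than a single minibatch."""

    def __init__(self, max_size):
        self.buf = deque(maxlen=max_size)  # the oldest rows are dropped automatically when full

    def push_batch(self, rows):
        self.buf.extend(rows)  # the current minibatch goes to the end of the queue

    def rows(self):
        return list(self.buf)


def minibatch_portion(q_matrix, batch_size):
    """The current minibatch was inserted last, so its SK targets are the final rows."""
    return q_matrix[-batch_size:]
```

In use, the current minibatch is pushed first, SK is run on `rows()`, and only `minibatch_portion(...)` of the result enters the gradient step on (5).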
[0109] The contrastive (or triplet loss) learning is described for example in “Learning a similarity metric discriminatively, with application to face verification” by Chopra et al published in IEEE Conference on Computer Vision and Pattern Recognition in 2005 and “Dimensionality Reduction by Learning an Invariant Mapping” by Hadsell et al published in IEEE Conference on Computer Vision and Pattern Recognition in 2006. The loss function penalizes small similarity scores for relevant pairs, and penalizes large similarity scores for irrelevant pairs. With the introduction of the margin and considering the most violating irrelevant pairs (i.e., hard negative mining), the loss L_c can be formally written as follows (the subscript c stands for contrastive):
$$L_c(\Phi^A, \Phi^B) = \frac{1}{N}\sum_{i=1}^{N}\Big[\max_{j \ne i}\big(s_{ii} - s_{ij}\big)_{\ge\alpha} + \max_{j \ne i}\big(s_{ii} - s_{ji}\big)_{\ge\alpha}\Big]$$
where (z)_{≥α} = max(0, α − z) incurs a positive loss only when z < α, α is the margin (e.g., 0.2), and s_ij = s(x_i^A, x_j^B) = Φ^A(x_i^A)^⊤ Φ^B(x_j^B) is the similarity between the i-th data instance in modality A and the j-th data instance in modality B.
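The contrastive loss with hardest-negative mining can be sketched as follows, assuming a precomputed similarity matrix with sim[i][j] = s(x_i^A, x_j^B) and averaging over the N pairs; the function name is hypothetical. Taking the hardest negative before applying the hinge is equivalent to taking the maximum of the hinge terms, because (z)_{≥α} is non-increasing in z.

```python
def hardest_negative_contrastive_loss(sim, alpha=0.2):
    """L_c with hard negative mining; sim[i][j] = s(x_i^A, x_j^B), diagonal = relevant pairs."""
    n = len(sim)
    total = 0.0
    for i in range(n):
        pos = sim[i][i]
        hardest_b = max(sim[i][j] for j in range(n) if j != i)  # query x_i^A vs. wrong x_j^B
        hardest_a = max(sim[j][i] for j in range(n) if j != i)  # query x_i^B vs. wrong x_j^A
        # (z)_{>=alpha} = max(0, alpha - z), applied to the similarity gap z = s_ii - s_neg
        total += max(0.0, alpha - (pos - hardest_b)) + max(0.0, alpha - (pos - hardest_a))
    return total / n
```

For example, with sim = [[0.5, 0.45], [0.1, 0.5]] and α = 0.2, only the negative 0.45 violates the margin (in both directions it involves pair (0, 1)), contributing 0.15 twice, so the averaged loss is 0.15.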
[0110] As explained above, the surrogate q.sup.A is estimated using the current classification model in modality B, and vice versa, so the class assignment is swapped. Therefore we name this approach SwAMP (Swapped Assignment of Multi-modal Pairs). The pseudo code of the SwAMP is described in the algorithm shown below.
[0111] Input: Class cardinality K, queue size, softmax temperature τ, regularizer trade-off η in SK.
[0112] Initialize: Prototypes P and Φ^A(·), Φ^B(·). Empty the queue Q.
[0113] Output: Trained model {P, Φ^A(·), Φ^B(·)}.
[0114] Repeat:
[0115] 1. Sample a minibatch B = {(x_i^A, x_i^B)} of paired data.
[0116] 2. Evaluate Φ^A(x_i^A) and Φ^B(x_i^B) for i ∈ B (forward pass).
[0117] 3. Insert {Φ^A(x_i^A), Φ^B(x_i^B)}_{i∈B} into the queue Q.
[0118] 4. Solve (3) & (4) for modalities A and B on the queue: Q^A = SK(P^B, η) and Q^B = SK(P^A, η), where P^A and P^B are computed from the queued features by (2).
[0119] 5. Take the minibatch portions {q^A(y|i), q^B(y|i)}_{i∈B}, and do one SGD update with L in (7).
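The pseudo code above can be illustrated by the following toy sketch of steps 3 to 5. It rests on strong simplifying assumptions that are not part of the described method: the minibatch rows are taken to be already-computed features Φ^A(x^A), Φ^B(x^B) of frozen extractors, only the shared prototypes P are updated (by one plain gradient step on L_s, using the softmax cross-entropy gradient (p − q)Φ(x)/τ), and the contrastive term of (7) is omitted. All names and default hyperparameter values are hypothetical.

```python
import math
from collections import deque


def probs_from_features(feats, protos, tau):
    """Classifier (2): p(y = j | x) = softmax_j(p_j . Phi(x) / tau) for each feature row."""
    out = []
    for f in feats:
        logits = [sum(a * b for a, b in zip(p_j, f)) / tau for p_j in protos]
        m = max(logits)  # subtract the max for numerical stability
        e = [math.exp(z - m) for z in logits]
        s = sum(e)
        out.append([x / s for x in e])
    return out


def sk(probs, eta, n_iter=100):
    """Sinkhorn-Knopp: diag(u) probs**eta diag(v), row sums 1 and column sums N/K."""
    n, k = len(probs), len(probs[0])
    ker = [[p ** eta for p in row] for row in probs]
    v = [1.0] * k
    for _ in range(n_iter):
        u = [1.0 / sum(a * b for a, b in zip(row, v)) for row in ker]
        v = [(n / k) / sum(ker[i][j] * u[i] for i in range(n)) for j in range(k)]
    return [[u[i] * ker[i][j] * v[j] for j in range(k)] for i in range(n)]


def swamp_step(protos, queue_a, queue_b, batch_a, batch_b, tau=0.5, eta=1.0, lr=0.1):
    """One toy SwAMP iteration (steps 3-5) that updates only the prototypes P.

    ASSUMPTION: batch_a / batch_b are features of frozen extractors, and only L_s is used."""
    queue_a.extend(batch_a)  # step 3: newest rows go to the end of the FIFO queues
    queue_b.extend(batch_b)
    p_a = probs_from_features(list(queue_a), protos, tau)
    p_b = probs_from_features(list(queue_b), protos, tau)
    q_a, q_b = sk(p_b, eta), sk(p_a, eta)  # step 4: swapped assignment, (3) and (4)
    n = len(batch_a)
    grads = [[0.0] * len(protos[0]) for _ in protos]
    # step 5: only the minibatch portion (last n rows) of p and q enters the update
    for feats, p, q in ((batch_a, p_a[-n:], q_a[-n:]), (batch_b, p_b[-n:], q_b[-n:])):
        for f, p_i, q_i in zip(feats, p, q):
            for j, grad_j in enumerate(grads):
                coef = (p_i[j] - q_i[j]) / (tau * n)  # dL_s/dp_j = (p_ij - q_ij) f_i / (tau n)
                for d, f_d in enumerate(f):
                    grad_j[d] += coef * f_d
    return [[w - lr * g for w, g in zip(p_j, g_j)] for p_j, g_j in zip(protos, grads)]
```

In the described method the same targets would also be back-propagated into Φ^A and Φ^B; the sketch freezes them only to stay self-contained without an automatic differentiation library.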
[0120] A dataset D = {(x_i^A, x_i^B)}_{i=1}^N of paired data instances is obtained (step S100). As illustrated in the pseudo code and as explained above, the entire dataset may be processed in batches, and thus there is an optional step of obtaining a batch of paired data instances from the dataset (step S102). In a forward pass of the machine learning model, the feature extractors Φ^A(x^A) and Φ^B(x^B) are evaluated for the data (i.e. for the batch of data where batches are being used) at step S104. Evaluating the feature extractors may comprise calculating a prototype attention representation (PAR) as described in more detail below.
[0121] Evaluating the data means that the features are extracted using the appropriate feature extractor. The data in modality A may be a single image or a set of images forming a video. Evaluating the data in modality A thus comprises processing the image (or images) to identify features therein and mapping them into a feature space, such as the feature space 310 shown in the drawings.
[0122] Using the evaluations, it is possible to cluster the evaluated data in each modality, which is shown in the separate branches, and then to update the feature extractor for the other modality. Thus at step S110, the data x^A in the first modality (A) is clustered using the feature representations evaluated by the corresponding feature extractor Φ^A. Each cluster is assigned a class label y^A. Based on the clustering for the data x^A in the first modality and the class label for each cluster, class labels y^B are assigned to the data x^B in the second modality (B) at step S112. In other words, as shown in the drawings, the class label of each data instance x_i^A in the first modality is assigned to the paired data instance x_i^B in the second modality. As explained above, the number of labels (and hence the number of clusters) can be selected by the user. The impact of the number of labels is shown in the drawings. The feature extractor Φ^B for the second modality (B) is then updated at step S114.
[0123] In a similar manner, at step S120, the data x^B in the second modality (B) is clustered using the feature representations evaluated by the corresponding feature extractor Φ^B. Based on the clustering for the data x^B, class labels y^A are assigned to the data x^A in the first modality (A) at step S122. As explained above, the number of labels can be selected by the user and may be the same as or different from the number of labels used in step S112. The next step is to update the feature extractor Φ^A for the first modality (A) at step S124.
[0124] After the feature extractors have been updated in steps S114, S124, the method loops back to perform an iterative update, e.g. using a different batch where batches are being used.
[0126] As explained above, the optimization problem of step S152 may be solved using the Sinkhorn-Knopp (SK) algorithm. Batches of data are processed to improve the optimization process, and thus at step S154 a batch of paired data instances is selected from the dataset and the feature extractors Φ^A, Φ^B are evaluated for this batch of data. The evaluations of the feature extractors are stored in (FIFO) queues. At step S154, the updated feature extractors Φ^A, Φ^B from step S180 may be used.
[0127] Thus at step S160, the embeddings Φ^A(x^A) from the latest batch are stored in a queue together with data from earlier minibatches and, similarly, at step S170, the embeddings Φ^B(x^B) from the latest batch are stored in a queue. Using the queued data, q^A and q^B can be optimised: as shown at steps S162 and S172, the optimization uses p(y|x^A) and p(y|x^B), with q^A estimated from p(y|x^B) and q^B estimated from p(y|x^A).
[0128] Once q^A and q^B have been obtained, a loss function may be used to optimise the trainable parameters P together with the feature extractors Φ^A, Φ^B. An example of a suitable loss function is the cross-entropy loss minimisation shown in equation (5). This optimization of step S180 may be achieved using gradient descent and uses only the current minibatch portion of q^A and q^B. At step S182 it is determined whether there are more batches of data to be processed and, if so, the process repeats steps S154 to S180.
[0129] Once there are no more batches to process, the final loss function may be optimised at step S184. As explained above, this optimization may use equation (7), which includes the contrastive loss function. The contrastive term is optional and may be omitted, although in at least some circumstances including it may improve the overall result. As an alternative to the contrastive loss function, other known loss functions could be combined with the SwAMP loss function. For example, in the sketch-based image retrieval example, the final loss function may include the triplet loss, domain loss and semantic loss described in the Doodle2Search paper by Dey et al. The final optimization optimises over the trainable parameters P and the feature extractors Φ^A, Φ^B. The final trained model may then be output at the end of the process. The model may be output to another device, e.g. an apparatus or user device for using the model.
[0130] In an embodiment, there is provided a computer-implemented method for performing cross-modal retrieval using a trained machine learning, ML, model.
[0131] In an embodiment, the cross-modal retrieval problem is selected from text-based video retrieval, sketch-based image retrieval and image-text retrieval.
[0132] We test the proposed SwAMP loss on several different types of real-world cross-modal retrieval problems: text-based video retrieval, sketch-based photo image retrieval, and image-text retrieval. For each problem/dataset, we choose the most popular and successful method in the literature and replace its loss function (mostly the contrastive loss) with the proposed SwAMP loss to demonstrate the performance improvement. For a fair comparison, we faithfully follow the same optimization strategy and hyperparameters as the baseline methods. First, we provide proof-of-concept synthetic experiments and an ablation study on the hyperparameters of SwAMP.
Synthetic Data
[0133] In this section we devise a synthetic dataset not only for a proof-of-concept test of our SwAMP algorithm, but also for analyzing the impact of the various hyperparameters and training options in the proposed algorithm. For the former, we focus especially on the retrieval performance improvement achieved by SwAMP compared to the contrastive loss and its popular variants (e.g., the online hard-example mining loss).
[0134] The dataset is constructed by the following procedure. We randomly generate 20 Gaussians in ℝ^5, each of which is considered to represent a semantic class. For each Gaussian (class), we sample latent vectors z ∈ ℝ^5, and each z generates a pair of instances (x^A ∈ ℝ^100, x^B ∈ ℝ^100) by x^A = f_A(z) and x^B = f_B(z), where f_A and f_B are randomly initialized fully-connected DNNs with two hidden layers of 50 units. We generate 500 pairs for each class, which leads to 10,000 data pairs, and split them into 7,000/1,000/2,000 train/validation/test sets. The validation recall-at-1 (R@1) performance is evaluated at every training epoch, and the model at the epoch with the best validation performance is selected as the final model. Note that during training we only use the paired data (x^A, x^B); the semantic class labels are hidden from the training algorithms.
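The generation procedure above can be sketched as follows. Several details are not given in the text and are our assumptions: the class means are drawn from N(0, 9I), each class Gaussian has unit covariance, the random networks use ReLU hidden activations with 1/√fan-in weight scaling, and the function names are hypothetical.

```python
import numpy as np


def random_mlp(rng, d_in, d_hidden, d_out):
    """Random fully-connected net with two hidden layers (ReLU assumed)."""
    dims = [d_in, d_hidden, d_hidden, d_out]
    weights = [rng.normal(0.0, 1.0 / np.sqrt(m), size=(m, n))
               for m, n in zip(dims[:-1], dims[1:])]

    def forward(z):
        h = z
        for i, w in enumerate(weights):
            h = h @ w
            if i < len(weights) - 1:
                h = np.maximum(h, 0.0)  # ReLU on hidden layers only
        return h

    return forward


def make_synthetic(seed=0, n_classes=20, n_per_class=500, d_latent=5, d_obs=100):
    rng = np.random.default_rng(seed)
    means = rng.normal(0.0, 3.0, size=(n_classes, d_latent))  # one Gaussian per class
    labels = np.repeat(np.arange(n_classes), n_per_class)
    z = means[labels] + rng.normal(size=(labels.size, d_latent))  # unit covariance
    f_a = random_mlp(rng, d_latent, 50, d_obs)
    f_b = random_mlp(rng, d_latent, 50, d_obs)
    x_a, x_b = f_a(z), f_b(z)
    idx = rng.permutation(labels.size)
    splits = np.split(idx, [7000, 8000])  # 7000 / 1000 / 2000
    # labels are returned for evaluation only; training uses the (x_a, x_b) pairs
    return [(x_a[s], x_b[s], labels[s]) for s in splits]


train, val, test = make_synthetic()
```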
[0135] For training, we adopt embedding networks ϕ^A(x^A) and ϕ^B(x^B) that are fully-connected neural nets with two hidden layers of 50 units. The embedding dimension is 5. We train models with this same network architecture using the contrastive loss and using our SwAMP loss. For both loss functions, the batch size is 128, the Adam optimizer described in “Adam: A Method for Stochastic Optimization” by Kingma and Ba, published in International Conference on Learning Representations 2015, is used with learning rate 10^-3, and the maximum number of epochs is 100.
[0136] For the contrastive loss, we adopt (online) hard-example mining with the margin parameter α = 0.1. For the SwAMP loss, the default parameters are as follows: temperature τ = 0.01 for the softmax classifier, and reciprocal impact η = 1/0.05 = 20 of the max-entropy regularizer in the Sinkhorn-Knopp algorithm (i.e., we add the entropic regularizer with weight η^-1 = 0.05 to the objective of the OT problem). Also, by default, we choose K = 1000 classes and a queue size of 1,280, which is 10 times the batch size (and greater than K). For both loss functions, the embedding networks are initialized randomly.
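For the contrastive baseline, online hard-example mining can be sketched as a VSE++-style triplet loss that keeps only the hardest in-batch negative for each query, in both retrieval directions. Averaging over the batch (rather than summing) and assuming L2-normalised embeddings are our choices; the function name is hypothetical.

```python
import numpy as np


def hard_negative_contrastive(emb_a, emb_b, margin=0.1):
    """Triplet ranking loss using the hardest in-batch negative (sketch).

    emb_a, emb_b: (B, D) L2-normalised embeddings; row i of each is a positive pair.
    """
    s = emb_a @ emb_b.T                  # s[i, j] = similarity of a_i and b_j
    pos = np.diag(s)                     # similarities of the true pairs
    off_diag = ~np.eye(len(s), dtype=bool)
    # query a_i against negatives b_j, and query b_j against negatives a_i
    cost_ab = np.where(off_diag, np.maximum(0.0, margin + s - pos[:, None]), 0.0)
    cost_ba = np.where(off_diag, np.maximum(0.0, margin + s - pos[None, :]), 0.0)
    return float(cost_ab.max(axis=1).mean() + cost_ba.max(axis=0).mean())


aligned = hard_negative_contrastive(np.eye(3), np.eye(3))
shuffled = hard_negative_contrastive(np.eye(3), np.eye(3)[[1, 2, 0]])
```

Perfectly aligned pairs incur no loss; when every query's most similar item is a wrong one, each direction pays margin + 1 = 1.1.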
[0137] For testing, we perform the cross-modal retrieval task x^A → x^B, treating each x^A in the test set as a query and retrieving x^B from the test set. There are two ways to define the retrieval error: i) pair-based, which treats the retrieved x^B as a correct retrieval only if the query x^A and the retrieved x^B appear as a pair in the data, and ii) class-based, which compares only the classes of the query x^A and the retrieved x^B. The pair-based error is therefore stricter than the class-based error, since it counts as correct only the item paired with the query in the data, without comparing the semantic classes of the retrieved item and the query.
Table 1. Retrieval results on the synthetic data

| Error type | Method | R@1 ↑ | R@5 ↑ | R@10 ↑ | Med-R ↓ |
|---|---|---|---|---|---|
| Pair-based | Contrastive | 84.10 | 98.60 | 99.55 | 1 |
| Pair-based | SwAMP | 90.80 | 99.95 | 100.0 | 1 |
| Class-based | Contrastive | 91.60 | 99.70 | 99.90 | 1 |
| Class-based | SwAMP | 95.70 | 99.95 | 100.0 | 1 |
[0138] In the table above, R@K is the percentage of queries for which the true item is found in the model's top-K retrieved items, and Med-R is the median rank of the true item in the model's ranking of all search items.
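The pair-based and class-based metrics can be computed as below. This is a sketch assuming dot-product similarity and that query i is paired with database row i; the function name is hypothetical. For the class-based variant, the rank is that of the first retrieved item sharing the query's class.

```python
import numpy as np


def retrieval_metrics(emb_q, emb_db, labels_q=None, labels_db=None, ks=(1, 5, 10)):
    """R@K (%) and Med-R for queries whose paired item is the same row of emb_db.

    Pair-based (no labels): only the paired item counts as correct.
    Class-based (labels given): any item of the query's class counts as correct.
    """
    sims = emb_q @ emb_db.T
    order = np.argsort(-sims, axis=1)                   # best match first
    if labels_q is None:
        hits = order == np.arange(len(emb_q))[:, None]
    else:
        hits = labels_db[order] == labels_q[:, None]
    rank = hits.argmax(axis=1) + 1                      # 1-based rank of first correct item
    recall = {k: 100.0 * float(np.mean(rank <= k)) for k in ks}
    return recall, float(np.median(rank))


queries = np.eye(4)
database = np.eye(4)[[1, 0, 3, 2]]      # each query's best match is another item of its class
labels = np.array([0, 0, 1, 1])
pair_recall, pair_medr = retrieval_metrics(queries, database)
class_recall, class_medr = retrieval_metrics(queries, database, labels, labels)
```

The toy example shows why the pair-based error is stricter: every top-1 result has the correct class, yet none is the paired item.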
Ablation Study on Hyperparameters
[0139] There are several hyperparameters in our SwAMP model, and we have conducted ablation studies on their impact. The hyperparameters deemed most critical are: i) the number of classes K, ii) the size of the queues, iii) the initialization of the feature extraction networks (either random or pretrained with the contrastive loss), iv) the entropic regularization trade-off η in Sinkhorn-Knopp, and v) soft versus hard cluster assignment after OT clustering.
[0140] Number of classes (K). We vary the number of classes K over {200, 500, 1000, 2000, 3000} and record the R@1 scores of our SwAMP model for both the pair-based and class-based error types. The results are shown in
[0141] Size of queues. Another important hyperparameter is the size of the queues, since the OT clustering is performed on the latest features stored in the queues. In addition to the default queue size 1,280 = 10 × 128 (batch size), we try queue sizes of {0, 1, 2, 5, 20} × 128. Note that the OT clustering is performed on the union of the features in the queue and the current batch; hence a queue size of zero implies that only the current batch is used for OT clustering. The results are reported in
[0142] Initialization of feature extractor networks. In our default setup, the feature extractor networks ϕ^A(·) and ϕ^B(·) are initialized randomly. We now test the performance of SwAMP when the feature extractor networks are initialized from networks pretrained with the contrastive loss. We initially expected that this warm start may expedite training with the SwAMP loss; however, as the results in
[0143] Impact of the entropic regularization (1/η). In the Sinkhorn-Knopp (SK) algorithm, the reciprocal 1/η is the trade-off weight of the entropy term of the optimization variables q(y|x). Placing too much emphasis on the entropy term (by increasing 1/η, i.e., decreasing η) leads to a near-uniform q(y|x), which carries little information about the meaningful classes, so the cluster assignment becomes more or less random. On the other hand, too small an impact of the entropy term makes the SK algorithm converge too slowly, so the output of SK after only a few iterations would be a non-optimal solution. To see the impact, we vary 1/η over 0.01, 0.05 (default) and 0.1, and the results are shown in
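The effect of the entropic weight 1/η on q(y|x) can be illustrated with a small experiment. This sketch assumes the same log-domain SK formulation as above (cost -log p(y|x), equal class marginals, three iterations), random classifier outputs, and adds an extreme value 1/η = 100 for contrast; the helper names are hypothetical. It measures only the sharpness of the assignments, not retrieval accuracy.

```python
import numpy as np


def logsumexp(a, axis):
    m = a.max(axis=axis, keepdims=True)
    return m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))


def sk_assign(log_p, reg, n_iters=3):
    """Log-domain Sinkhorn-Knopp with entropic weight reg = 1/eta (sketch)."""
    n, k = log_p.shape
    log_q = log_p / reg                                    # eta * log p(y|x)
    for _ in range(n_iters):
        log_q = log_q - logsumexp(log_q, axis=0) - np.log(k)
        log_q = log_q - logsumexp(log_q, axis=1) - np.log(n)
    return np.exp(log_q + np.log(n))


def mean_row_entropy(q):
    return float(-(q * np.log(q + 1e-300)).sum(axis=1).mean())


rng = np.random.default_rng(0)
logits = 3.0 * rng.normal(size=(64, 10))
log_p = logits - logsumexp(logits, axis=1)               # log p(y|x)
entropies = {reg: mean_row_entropy(sk_assign(log_p, reg)) for reg in (0.01, 0.05, 0.1, 100.0)}
```

A large 1/η drives the mean row entropy towards log K, i.e., the near-uniform, uninformative assignments described above.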
[0144] Soft or hard cluster assignment after OT. As is known in the art, hard cluster assignment assigns a cluster label to a data point in an all-or-nothing manner, whereas soft assignment assigns probability values. Merely as an example, for K = 3 (i.e., clusters 1, 2 and 3), hard cluster assignment may assign a data point to (0, 1, 0), meaning that it belongs to cluster 2 and nothing else. By contrast, soft cluster assignment may assign a data point to (0.1, 0.8, 0.1), meaning that the probability of belonging to cluster 2 is 80% and that of each of the other clusters is 10%. Soft cluster assignment thus allows uncertainty in the cluster assignment and keeps the objective differentiable, which is not true for hard cluster assignment.
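The two options can be compared on the K = 3 example above. In this sketch, the hard variant replaces q(y|x) by its row-wise argmax as a one-hot vector before the cross-entropy; the classifier output p is a hypothetical value chosen for illustration, and the helper names are ours.

```python
import numpy as np


def harden(q):
    """One-hot thresholding of soft assignments q(y|x) via the row-wise argmax."""
    hard = np.zeros_like(q)
    hard[np.arange(len(q)), q.argmax(axis=1)] = 1.0
    return hard


def cross_entropy(targets, probs):
    return float(-(targets * np.log(probs)).sum(axis=1).mean())


q_soft = np.array([[0.1, 0.8, 0.1]])   # the K = 3 soft example from the text
q_hard = harden(q_soft)                # [[0., 1., 0.]]
p = np.array([[0.2, 0.6, 0.2]])        # hypothetical classifier output p(y|x)
ce_soft = cross_entropy(q_soft, p)
ce_hard = cross_entropy(q_hard, p)
```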
[0145] We also check whether hard cluster assignment by thresholding after the OT optimization is beneficial. Recall that the default is to use the output q(y|x) of the SK algorithm as it is (i.e., soft cluster assignment). For hard assignment, we further threshold q(y|x) to a one-hot encoding, which is then used in the cross-entropy loss optimization. As shown in
Text-Based Video Retrieval
[0146] We consider the text-to-video retrieval task, where the goal is to find the most relevant video clip for a given natural language text query. We consider three datasets for this task: i) YouCook2 [“Towards automatic learning of procedures from web instructional videos” by Zhou et al published in The Thirty-Second AAAI Conference on Artificial Intelligence in 2018] of cooking videos and instructions, ii) MSR-VTT [“A large video description dataset for bridging video and language” by Xu et al published in IEEE/CVF Conference on Computer Vision and Pattern Recognition in 2016] of generic videos and captions from YouTube, and iii) LSMDC [“Movie description” by Rohrbach et al published in International Journal of Computer Vision, 123: 94-120 in 2017] of movie clips and subtitles. All these datasets provide pairs of a video clip and its text description, forming a multi-modal paired data format (text, video) that conforms to our SwAMP framework.
[0147] For the raw text/video features and the feature extractor networks, as well as the training/test protocols, we follow the methods in “HowTo100M: Learning a Text-Video Embedding by Watching Hundred Million Narrated Video Clips” by Miech et al published in International Conference on Computer Vision in 2019. The features are built by the following procedure. First, the raw features are obtained from pretrained networks. The raw video features (4096D) are a concatenation of frame-level and video-level features extracted from pretrained 2D/3D CNNs. The 2D features may be extracted using the ImageNet-pretrained ResNet-152, for example as described in “Momentum Contrast for Unsupervised Visual Representation Learning” by He et al published as arXiv preprint arXiv:1911.05722 in 2016. Features based on the Kinetics dataset may be extracted using any suitable technique, for example the technique described in “Quo vadis, action recognition? A new model and the kinetics dataset” by Carreira and Zisserman published in IEEE/CVF Conference on Computer Vision and Pattern Recognition 2017. The 3D features may be extracted using the pretrained ResNeXt-101 16-frame model, for example as described in “Can spatiotemporal 3D CNNs retrace the history of 2D CNNs and ImageNet?” by Hara, Kataoka, and Satoh published in IEEE/CVF Conference on Computer Vision and Pattern Recognition in 2018. The raw text features (4096D) may be any suitable features, for example the GoogleNews-pretrained word2vec embeddings described in “Efficient estimation of word representations in vector space” by Mikolov et al, available as arXiv preprint arXiv:1301.3781 (2013), computed for the pre-processed transcribed video narrations with common stop words removed.
[0148] The feature extractor networks ϕ^video(·) and ϕ^text(·) then transform these raw features into 4096D features (i.e., feature vectors of 4096 dimensions) by a sigmoid-gated linear transform, where the gating functions are two-layer linear networks, for example as described in “Learning a Text-Video Embedding from Incomplete and Heterogeneous Data” by Miech et al published as arXiv preprint arXiv:1804.02516 in 2018. We fix the raw features and train only the sigmoid-gated networks, which comprise about 67 M parameters.
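One plausible form of the sigmoid-gated linear transform, following our reading of the gated embedding unit of Miech et al. (2018), is a linear projection followed by an elementwise sigmoid gate computed from that projection, then L2 normalisation. The exact layer arrangement and the final normalisation are assumptions, and the function names are hypothetical. With 4096-D inputs and outputs, two 4096×4096 weight matrices per modality give about 67 M parameters over the two modalities, which would be consistent with the figure quoted above.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gated_embedding(x, w1, b1, w2, b2):
    """Sigmoid-gated linear transform (sketch): z = W1 x + b1, out = z * sigmoid(W2 z + b2)."""
    z = x @ w1 + b1                                       # linear projection
    z = z * sigmoid(z @ w2 + b2)                          # elementwise gating
    return z / np.linalg.norm(z, axis=1, keepdims=True)  # L2 normalisation (assumed)


def n_params(d_in, d_out):
    # w1 (d_in x d_out) + w2 (d_out x d_out) + two bias vectors of length d_out
    return d_in * d_out + d_out * d_out + 2 * d_out


rng = np.random.default_rng(0)
d = 16                                                    # small stand-in for 4096
w1, w2 = rng.normal(size=(d, d)) / np.sqrt(d), rng.normal(size=(d, d)) / np.sqrt(d)
b1, b2 = np.zeros(d), np.zeros(d)
video_emb = gated_embedding(rng.normal(size=(8, d)), w1, b1, w2, b2)
total = 2 * n_params(4096, 4096)                          # video + text branches
```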
[0149] There are two training strategies: i) No-pretraining (No-PT), where the feature extraction networks are randomly initialized and training is done on the training split of the dataset, and ii) Pretraining (PT), where the feature extractors are first pretrained on the large-scale HowTo100M dataset as described in the 2019 paper by Miech et al, and then finetuned on the target dataset. The prior approach by Miech et al adopts the contrastive (triplet) loss for training the feature extractors. Although we also compare our approach with state-of-the-art methods, the main focus of this experiment is to demonstrate the performance improvement achieved by the proposed SwAMP loss over vanilla contrastive learning. The SwAMP hyperparameter λ, the weight (impact) of the SwAMP loss relative to the contrastive loss, is chosen as λ = 0.25 for all three datasets, except for the LSMDC-PT case, for which λ = 0.1. We also choose softmax temperature τ = 0.25, entropic regularization trade-off η = 5.0 in SK, K = 500 classes, and a queue size of 2,048 for SwAMP. The other learning hyperparameters common to the SwAMP and contrastive losses are unchanged from the prior art approach of Miech et al.
[0150] YouCook2. This cooking video dataset, collected from YouTube, contains 89 recipes and 14 K video clips annotated with textual descriptions by paid human workers. The test data are formed by taking 3.5 K clips from the validation set, and the test set comprises 3,350 pairs. The retrieval performance metrics are recall-at-k (R@k) with k = 1, 5, 10 and the median rank (Med-R). Hence, a random guess attains R@1 = 0.03% and Med-R = 1,675. The results are summarized in Table 2. In the bottom four rows, we see the performance improvement achieved by the proposed SwAMP over the contrastive loss [Miech et al 2019]. For both training strategies, No PT (random model initialization) and PT (initialized with the HowTo100M-pretrained model), SwAMP improves the retrieval performance significantly (e.g., about a 12% reduction in median rank for the No PT case). SwAMP also outperforms the CCA baseline FV-CCA described in “Associating neural word embeddings with deep image representations using Fisher vectors” by Klein et al published in IEEE/CVF Conference on Computer Vision and Pattern Recognition in 2015.
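The random-guess row of Table 2 and the quoted Med-R reduction follow from simple arithmetic, shown here as a quick check. Taking Med-R ≈ n/2 is an approximation; the exact median of a rank uniform on 1..n is (n + 1)/2.

```python
n = 3350                                                  # test pairs in YouCook2
random_recall = {k: 100.0 * k / n for k in (1, 5, 10)}   # % chance the true clip is in the top k
random_medr = n / 2                                       # rank of the true clip is uniform on 1..n
medr_cut_no_pt = 100.0 * (65 - 57) / 65                   # Med-R: contrastive 65 -> SwAMP 57 (No PT)
```

Rounded, these give R@1/R@5/R@10 of 0.03/0.15/0.3%, Med-R of 1,675, and a Med-R reduction of about 12.3%.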
Table 2. Text-video retrieval results on YouCook2

| Methods | R@1 ↑ | R@5 ↑ | R@10 ↑ | Med-R ↓ |
|---|---|---|---|---|
| Random | 0.03 | 0.15 | 0.3 | 1675 |
| FV-CCA | 4.6 | 14.3 | 21.6 | 75 |
| Contrastive (No PT) | 4.2 | 13.7 | 21.5 | 65 |
| SwAMP (No PT) | 4.8 | 14.5 | 22.5 | 57 |
| Contrastive (PT) | 8.2 | 24.5 | 35.3 | 24 |
| SwAMP (PT) | 9.4 | 24.9 | 35.3 | 22 |
[0151] In the table above and in Tables 3 and 4 below, the improved scores of SwAMP over contrastive learning are boldfaced. R@K is the percentage of queries for which the true item is found in the model's top-K retrieved items, and Med-R is the median rank of the true item in the model's ranking of all search items. For these tasks, the query modality is the text (i.e., the caption) and the search modality is the video.
[0152] MSRVTT. This generic video-text dataset collected from YouTube contains videos of specific categories including music, sports and movies. There are 200 K video-caption pairs obtained by human annotation. We follow the retrieval training/test protocol described in the prior art approaches of Yu et al and Miech et al. The test set consists of 1 K pairs, and the results are reported in Table 3. The results of the SwAMP approach are compared with the prior art approaches C+LSTM+SA+FC7, described in “Learning language visual embedding for movie understanding with natural language” by Torabi et al published as arXiv preprint arXiv:1609.08124 in 2016; VSE-LSTM, described in “Unifying visual-semantic embeddings with multimodal neural language models” by Kiros et al published as arXiv preprint arXiv:1411.2539 in 2014; Temporal Tessellation, described in “Temporal tessellation: A unified approach for video analysis” by Kaufman et al published in Proceedings of the IEEE International Conference on Computer Vision in 2017; CT-SAN, described in “End-to-end concept word detection for video captioning, retrieval, and question answering” by Yu et al published in IEEE/CVF Conference on Computer Vision and Pattern Recognition in 2017; and JSFusion, described in “A joint sequence fusion model for video question answering and retrieval” by Yu et al published in European Conference on Computer Vision in 2018.
Table 3. Text-video retrieval results on MSRVTT

| Methods | R@1 ↑ | R@5 ↑ | R@10 ↑ | Med-R ↓ |
|---|---|---|---|---|
| Random | 0.1 | 0.5 | 1.0 | 500 |
| C+LSTM+SA+FC7 | 4.2 | 12.9 | 19.9 | 55 |
| VSE-LSTM | 3.8 | 12.7 | 17.1 | 66 |
| SNUVL | 3.5 | 15.9 | 23.8 | 44 |
| Temporal Tessellation | 4.7 | 16.6 | 24.1 | 41 |
| CT-SAN | 4.4 | 16.6 | 22.3 | 35 |
| JSFusion | 10.2 | 31.2 | 43.2 | 13 |
| Contrastive (No PT) | 12.1 | 35.0 | 48.0 | 12 |
| SwAMP (No PT) | 15.0 | 38.5 | 50.3 | 10 |
| Contrastive (PT) | 14.9 | 40.2 | 52.8 | 9 |
| SwAMP (PT) | 19.0 | 42.4 | 55.2 | 8 |
[0153] As reported in Table 3, our SwAMP loss improves the performance over contrastive learning significantly for both the no-pretraining and pretraining cases: a relative improvement in R@1 of about 24% in the No PT case and about 27% in the PT case. Furthermore, SwAMP outperforms all of the other listed state-of-the-art approaches by a large margin.
[0154] LSMDC. The LSMDC is a dataset of movie video clips comprising 101 K video-caption pairs. The captions are collected either from the movie scripts or from the audio descriptions. The test set contains 1 K pairs, and the results are reported in Table 4 alongside the results of the prior art approaches described in relation to Table 3. For the LSMDC dataset, we use the SwAMP hyperparameter (impact of the SwAMP loss relative to the contrastive loss) λ = 0.1 for the PT case. As with the other two datasets, SwAMP is consistently better than contrastive learning (about a 7-9% reduction in median rank).
Table 4. Text-video retrieval results on LSMDC

| Methods | R@1 ↑ | R@5 ↑ | R@10 ↑ | Med-R ↓ |
|---|---|---|---|---|
| Random | 0.1 | 0.5 | 1.0 | 500 |
| C+LSTM+SA+FC7 | 4.3 | 12.6 | 18.9 | 98 |
| VSE-LSTM | 3.1 | 10.4 | 16.5 | 79 |
| SNUVL | 3.6 | 14.7 | 23.9 | 50 |
| Temporal Tessellation | 4.7 | 15.9 | 23.4 | 64 |
| CT-SAN | 4.5 | 14.1 | 20.9 | 67 |
| JSFusion | 9.1 | 21.2 | 34.1 | 36 |
| Contrastive (No PT) | 7.2 | 18.3 | 25.0 | 44 |
| SwAMP (No PT) | 7.7 | 19.3 | 27.7 | 40 |
| Contrastive (PT) | 7.1 | 19.6 | 27.9 | 40 |
| SwAMP (PT) | 8.3 | 20.0 | 28.9 | 37 |
Sketch-Based Image Retrieval
[0155] We next apply the SwAMP approach to the sketch-based image retrieval task. The model takes a user's sketch (quick drawing) of an object as an input query and retrieves the photo images that correspond to the same object category as the query. Thus, for this task, the query modality is the sketch image (i.e., the human's quick drawing) and the search modality is the photo image. Sketch-to-image retrieval has gained considerable attention due to the pervasive availability of touch screens and similar drawing devices.
[0156] We follow the recent framework from “Doodle to Search: Practical Zero-Shot Sketch-based Image Retrieval” by Dey et al published in IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) in 2019. This article also reports state-of-the-art performance on three large-scale sketch-image benchmark datasets: the Sketchy-Extended dataset described in “The sketchy database: Learning to retrieve badly drawn bunnies” by Sangkloy et al published in ACM Transactions on Graphics, 35(4): 1-12 in 2016; the TU-Berlin-Extended dataset described in “How do humans sketch objects?” by Eitz et al published in ACM Transactions on Graphics, 31(4): 1-10 in 2012; and the QuickDraw-Extended dataset, which is described in the article by Dey et al. The datasets consist of roughly 100 to 200 object categories with hundreds to thousands of sketch/photo images for each category. For all these datasets, we use the zero-shot setting, meaning that the training and test splits contain instances from disjoint object categories.
[0157] In this experiment we aim to show the improvement in retrieval performance when our SwAMP loss is added to the existing loss function. To this end, we use the same embedding networks for images and sketches, as well as the same loss function, as Doodle2Search. The loss function consists of three losses: the triplet loss, which is the conventional triplet loss; the domain loss, which uses an adversarial domain classifier to penalize misalignment between the distributions of the embeddings of photo images and sketches; and the semantic loss, which urges the embeddings of the photo images and sketches to reconstruct the pretrained word embedding of the corresponding object word. We also use the same attention-based embedding networks for the photo and sketch modalities. We then add our SwAMP loss to the Doodle2Search loss with impact λ = 0.1 for all three datasets. The combined loss function may thus be defined as:
where the first and last terms are the same as in equation (7), L_d is the domain loss, L_sm is the semantic loss, and the α coefficients are mixing proportions.
[0158] We use a queue size of 1,000 (2,000 for the QuickDraw-Extended dataset), class cardinality K = 500, softmax temperature τ = 0.25, and entropic regularization impact η = 5.0. The resulting retrieval performance on the three datasets is summarized in Table 5, where it is compared with three prior art methods: Doodle2Search (denoted D2S), which is described in the Dey article; ZSIH, described in “Zero-shot sketch-image hashing” by Shen et al published in IEEE/CVF Conference on Computer Vision and Pattern Recognition; and CVAE, described in “A zero-shot framework for sketch based image retrieval” by Yelamarthi et al published in European Conference on Computer Vision in 2018.
Table 5. Sketch-based image retrieval results on three sketch datasets

| Methods | Sketchy-Extended [Sangkloy et al 2016] mAP | mAP@200 | P@200 | TU-Berlin-Extended [Eitz, Hays and Alexa 2012] mAP | mAP@200 | P@200 |
|---|---|---|---|---|---|---|
| ZSIH | 25.40 | - | - | 22.00 | - | - |
| CVAE | 19.59 | 22.50 | 33.30 | 0.50 | 0.90 | 0.30 |
| D2S | 36.91 | 46.06 | 37.04 | 10.94 | 15.68 | 12.08 |
| SwAMP | 40.32 | 51.94 | 40.81 | 17.63 | 24.49 | 19.75 |
[0159] In the table above, the improved scores are marked in bold. As shown, our SwAMP loss, when added to the existing contrastive-based loss described in the Dey article, significantly improves the retrieval performance (a relative improvement in mAP of about 9% for the Sketchy dataset and about 60% for the TU-Berlin dataset). The metrics are mean average precision (mAP); mAP@200, which is the mean average precision computed over the model's top 200 retrieved items; and P@200, which is the precision of the model's top 200 retrieved items.
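The three metrics can be sketched as follows for ranked 0/1 relevance lists, where a retrieved photo is relevant if it shares the sketch query's category. Normalising AP@k by the number of relevant items retrieved within the cutoff is one common convention and an assumption here; other papers normalise differently. The function names are hypothetical, and k = 2 stands in for k = 200 in the toy example.

```python
import numpy as np


def average_precision(relevance, k=None):
    """AP of a ranked 0/1 relevance list, optionally truncated to the top k."""
    rel = np.asarray(relevance, dtype=float)
    if k is not None:
        rel = rel[:k]
    if rel.sum() == 0:
        return 0.0
    precision_at_i = np.cumsum(rel) / np.arange(1, len(rel) + 1)
    return float((precision_at_i * rel).sum() / rel.sum())


def precision_at_k(relevance, k):
    return float(np.mean(np.asarray(relevance, dtype=float)[:k]))


# each row: relevance of the ranked photos for one sketch query (same category = 1)
ranked = [[1, 0, 1, 0], [0, 1, 0, 0]]
map_all = float(np.mean([average_precision(r) for r in ranked]))
map_at_2 = float(np.mean([average_precision(r, k=2) for r in ranked]))
p_at_2 = float(np.mean([precision_at_k(r, 2) for r in ranked]))
```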
Image-Text Retrieval
[0160] For the image-text cross-modal retrieval task, we follow the features and protocols of the well-known SCAN paper “Stacked Cross Attention for Image-Text Matching” by Lee et al published in European Conference on Computer Vision in 2018. Thus, each image may be represented by a set of k local features V = {v_1, ..., v_k},
[0161] with v_i = W_v f_i + b_v ∈ ℝ^D, where the raw features f_i are fixed and {W_v, b_v} are learnable parameters. The f_i are the CNN features extracted from salient image regions detected by the Faster R-CNN model described in “Faster R-CNN: Towards real-time object detection with region proposal networks” by Ren et al published in Advances in Neural Information Processing Systems in 2015. The text (sentence) is also treated as a set of n word features E = {e_1, ..., e_n}, where the e_i are obtained from the outputs of a bi-directional GRU (gated recurrent unit) that takes the sequence of word embeddings as input. These outputs may be determined as explained in “Neural machine translation by jointly learning to align and translate” by Bahdanau et al published in International Conference on Learning Representations in 2015 or “Bidirectional recurrent neural networks” by Schuster et al published in IEEE Transactions on Signal Processing, 45(11): 2673-2681 in 1997. Both the word embeddings and the GRU parameters are learnable. These image/text features contain rich local information; however, one challenge is that both representations are sets, so the number of elements (k and n) can vary from instance to instance.
[0165] In the original SCAN paper, the authors proposed a cross-modal attention model in which each local feature from one modality is transformed by attention over the set of local features of the other modality. For example, v_i is transformed to v̂_i, the weighted sum of the values {e_1, ..., e_n} with v_i as the query and {e_1, ..., e_n} as the keys (denoted i-t); alternatively, e_i is transformed to ê_i, the weighted sum of the values {v_1, ..., v_k} with e_i as the query and {v_1, ..., v_k} as the keys (denoted t-i). The similarity score between image V and text E is then defined as S(V, E) = pool_i cos(v_i, v̂_i) for i-t (or pool_i cos(e_i, ê_i) for t-i), where cos(a, b) is the cosine similarity and pool is the pooling operation, either AVG or LSE (log-sum-exp). Further information on this attention transformation is provided in “Attention Is All You Need” by Vaswani et al published in Advances in Neural Information Processing Systems in 2017. The triplet contrastive loss described in the Doodle2Search paper is then employed.
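A simplified sketch of the i-t similarity is given below to make the cross-modal cost concrete: every image-text pair requires a k×n region-word similarity matrix. It uses plain softmax attention over cosine similarities with an assumed temperature; SCAN's actual attention additionally normalises and thresholds the similarities, and the temperature and LSE λ values here are assumptions. The function names are hypothetical.

```python
import numpy as np


def l2n(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def scan_i_t_score(v, e, temp=0.1, pool="avg", lse_lambda=6.0):
    """Simplified SCAN-style i-t similarity S(V, E) (sketch).

    v: (k, D) region features, e: (n, D) word features.  The (k, n) matrix
    below is the source of the k * n cost per image-text pair.
    """
    v, e = l2n(v), l2n(e)
    scores = v @ e.T / temp                                # regions query the words
    att = np.exp(scores - scores.max(axis=1, keepdims=True))
    att /= att.sum(axis=1, keepdims=True)                  # softmax over the n words
    v_hat = att @ e                                        # attended text vector per region
    r = np.sum(v * l2n(v_hat), axis=1)                     # cos(v_i, v_hat_i)
    if pool == "avg":
        return float(r.mean())
    return float(np.log(np.exp(lse_lambda * r).sum()) / lse_lambda)  # LSE pooling


basis = np.eye(6)
matched = scan_i_t_score(basis[:3], basis[:3])
unrelated = scan_i_t_score(basis[:3], basis[3:])
matched_lse = scan_i_t_score(basis[:3], basis[:3], pool="lse")
```

Because the attention depends on both inputs, no per-modality vector can be cached, which is the limitation discussed next.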
[0171] Note that SCAN has no succinct modality-wise embedding vector; instead, the similarity score between instances of the two modalities is computed by highly complex attention operations. Although this helps to capture the interactions between local features, computing the similarity score takes time quadratic in the number of elements (local features) in the instances. This is time consuming compared with a simple dot product of modality-wise embedding vectors (see Table 8 for the actual running times compared with an approach based on modality-wise feature representation). Moreover, SCAN is not applicable to our SwAMP approach, since SwAMP needs to predict the class labels for each modality from a modality-wise representation ϕ^image(V), ϕ^text(E).
[0172] To obtain a modality-wise representation, we adopt the idea of induced-set attention (ISA) from the Set Transformer, which is described in “Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks” by Lee et al published in Proceedings of the 36th International Conference on Machine Learning, 3744-3753 in 2019. Specifically, we introduce p learnable prototype (query) vectors {q_1, ..., q_p}, where q_j ∈ ℝ^D. In other words, the query vectors are learnable parameters used in feature extraction. We compute the attention of each query q_j over the set of local features V of the image.
[0173] We then define ϕ^image(V) as the concatenation (concat) of the p attention outputs. Thus, the parameters of ϕ^image(·) are {W_v, b_v} and the learnable prototype vectors {q_1, ..., q_p}.
[0174] Similarly, for the set of local features E of the text, we compute the attention of each query q_j over E.
[0175] We then define ϕ^text(E) as the concatenation of the p attention outputs. Thus, the parameters of ϕ^text(·) are the word embeddings, the GRU parameters, and {q_1, ..., q_p}.
We share the same prototype vectors {q_1, ..., q_p} for both modalities. We also use a multi-head extension that computes these features multiple times and concatenates them. We call these modality-wise features a prototype attention representation (PAR). Note that computing the PAR features has complexity linear in the number of local features (assuming p is constant), and the cross-modal similarity is simply the dot product of the PAR features, which can also be computed in linear time (see also Table 8 below). Computing the PAR features is part of the feature extraction and may thus be part of the step S104 in
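The prototype attention representation can be sketched as follows. This is our reading under assumptions: each prototype q_j attends over the local features with a softmax of dot-product scores divided by the attention temperature T, the p attended vectors are concatenated, and the multi-head version repeats this with a separate prototype set per head. Function names are hypothetical. The cost is O(p·m·D) for m local features, i.e., linear in m, and the output length p·H·D does not depend on m, so sets of different sizes map to directly comparable vectors.

```python
import numpy as np


def prototype_attention(x, prototypes, temp=0.5):
    """One PAR head: fixed-size vector for a variable-size set of local features.

    x: (m, D) local features; prototypes: (p, D) learnable queries q_1..q_p.
    """
    scores = prototypes @ x.T / temp                      # (p, m), linear in m
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)                     # softmax over the m features
    return (w @ x).reshape(-1)                            # concat of p attended vectors


def par(x, heads, temp=0.5):
    """Multi-head PAR: concatenate heads, each with its own prototype set (assumed)."""
    return np.concatenate([prototype_attention(x, q, temp) for q in heads])


rng = np.random.default_rng(0)
D, p, H = 8, 10, 2
heads = [rng.normal(size=(p, D)) for _ in range(H)]     # shared by image and text
image_regions = rng.normal(size=(36, D))                # k = 36 region features
caption_words = rng.normal(size=(12, D))                # n = 12 word features
phi_image, phi_text = par(image_regions, heads), par(caption_words, heads)
similarity = float(phi_image @ phi_text)                # cross-modal score: a dot product
```

Since the attention pools over the set, the result is invariant to the order of the local features.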
Datasets and Results
[0176] We test our approach on the popular image-text retrieval datasets MS-COCO and Flickr30K. Flickr30K contains 31K images with five captions for each image. MS-COCO contains 123,287 images, each annotated with five text descriptions. Following the widely used split described, for example, in “VSE++: Improved visual-semantic embeddings with hard negatives” by Faghri et al published in Proc. of British Machine Vision Conference in 2018, for Flickr30K we use 1 K images for validation, 1 K images for testing, and the rest for training. For MS-COCO, there are 5 K test images (and 25 K captions, five for each image). We also follow the two standard protocols for measuring test retrieval performance on MS-COCO: 1) using all 5 K test images, or 2) splitting the test set into 5 folds and reporting the average retrieval performance over the 5 folds.
[0177] The results are summarized in Table 6 (Flickr30K) and Table 7 (MS-COCO). In both tables, the results are compared with known prior art methods, including: DAN, described in “Dual attention networks for multimodal reasoning and matching” by Nam et al published in IEEE/CVF Conference on Computer Vision and Pattern Recognition in 2017; DPC, described in “Dual-path convolutional image-text embedding” by Zheng et al published as arXiv preprint arXiv:1711.05535 in 2017; VSE++, described in “VSE++: Improved visual-semantic embeddings with hard negatives” by Faghri et al published in Proc. of British Machine Vision Conference in 2018; SCO, described in “Learning semantic concepts and order for image and sentence matching” by Huang et al published in IEEE/CVF Conference on Computer Vision and Pattern Recognition in 2018; GXN, described in “Look, imagine and match: Improving textual-visual cross-modal retrieval with generative models” by Gu et al published in IEEE/CVF Conference on Computer Vision and Pattern Recognition in 2018; and PCME, described in “Probabilistic Embeddings for Cross-Modal Retrieval” by Chun et al published in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 8415-8424 in 2021. The results are also compared with different variations of the methods described in the original SCAN paper (i.e., “Stacked cross attention for image-text matching” by Lee et al published in the European Conference on Computer Vision in 2018). For example, SCAN i-t refers to transforming the features of the image modality by attention with the features of the text modality, and SCAN t-i refers to transforming the features of the text modality by attention with the features of the image modality. AVG indicates that an averaging function is used as the pooling function when calculating the similarity score, and LSE indicates that a log-sum-exp function is used as the pooling function.
Table 6. Image-text retrieval results on Flickr30K

| Methods | Image → Text R@1 | Image → Text R@5 | Image → Text R@10 | Text → Image R@1 | Text → Image R@5 | Text → Image R@10 |
|---|---|---|---|---|---|---|
| DAN | 55.0 | 81.8 | 89.0 | 39.4 | 69.2 | 79.1 |
| DPC | 55.6 | 81.9 | 89.5 | 39.1 | 69.2 | 80.9 |
| VSE++ | 52.9 | - | 87.2 | 39.6 | - | 79.5 |
| SCO | 55.5 | 82.0 | 89.3 | 41.1 | 70.5 | 80.1 |
| SCAN i-t AVG | 67.9 | 89.0 | 94.4 | 43.9 | 74.2 | 82.8 |
| SCAN t-i AVG | 61.8 | 87.5 | 93.7 | 45.8 | 74.4 | 83.0 |
| SCAN t-i AVG + i-t LSE | 67.4 | 90.3 | 95.8 | 48.6 | 77.7 | 85.2 |
| Contrastive-PAR | 65.7 | 86.8 | 92.4 | 48.2 | 75.8 | 84.2 |
| SwAMP-PAR | 67.8 | 88.5 | 94.0 | 49.1 | 76.1 | 83.7 |
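Table 7. Image-text retrieval results on MS-COCO

5-fold (1 K test images):

| Methods | Image → Text R@1 | Image → Text R@5 | Image → Text R@10 | Text → Image R@1 | Text → Image R@5 | Text → Image R@10 |
|---|---|---|---|---|---|---|
| DPC | 65.6 | 89.8 | 95.5 | 47.1 | 79.9 | 90.0 |
| VSE++ | 64.6 | - | 95.7 | 52.0 | - | 92.0 |
| GXN | 68.5 | - | 97.9 | 56.6 | - | 94.5 |
| SCO | 69.9 | 92.9 | 97.5 | 56.7 | 87.5 | 94.8 |
| PCME | 68.8 | - | - | 54.6 | - | - |
| SCAN i-t | 69.2 | 93.2 | 97.5 | 54.4 | 86.0 | 93.6 |
| SCAN t-i + i-t | 72.7 | 94.8 | 98.4 | 58.8 | 88.4 | 94.8 |
| Contrastive-PAR | 71.8 | 94.3 | 97.9 | 56.8 | 86.9 | 93.8 |
| SwAMP-PAR | 72.6 | 94.6 | 98.0 | 57.4 | 87.6 | 94.1 |

Entire (5 K test images):

| Methods | Image → Text R@1 | Image → Text R@5 | Image → Text R@10 | Text → Image R@1 | Text → Image R@5 | Text → Image R@10 |
|---|---|---|---|---|---|---|
| DPC | 41.2 | 70.5 | 81.1 | 25.3 | 53.4 | 66.4 |
| VSE++ | 41.3 | - | 81.2 | 30.3 | - | 72.4 |
| GXN | 42.0 | - | 84.7 | 31.7 | - | 74.6 |
| SCO | 42.8 | 72.3 | 83.0 | 33.1 | 62.9 | 75.5 |
| PCME | 44.2 | - | - | 31.9 | - | - |
| SCAN i-t | 46.4 | 77.4 | 87.2 | 34.4 | 63.7 | 75.7 |
| SCAN t-i + i-t | 50.4 | 82.2 | 90.0 | 38.6 | 69.3 | 80.4 |
| Contrastive-PAR | 48.4 | 78.1 | 88.1 | 34.3 | 64.4 | 76.2 |
| SwAMP-PAR | 49.7 | 79.1 | 88.3 | 35.0 | 65.1 | 76.6 |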
[0178] In Tables 6 and 7 above, the improved scores of SwAMP over contrastive learning are boldfaced. R@K is the percentage of queries for which the true item is found in the model's top-K retrieved items. In these tables, Image → Text denotes the case where the query modality is the image and the search modality is the text (i.e., the caption or short description of the image), and Text → Image denotes the case where the query modality is the text and the search modality is the image.
[0179] We specifically highlight the comparison between the contrastive loss and our SwAMP loss with the modality-wise feature representation (Contrastive-PAR vs. SwAMP-PAR). For the PAR features, we choose p = 20 prototypes, attention weight temperature T = 0.5 and H = 1 head for Flickr30K, and p = 10, T = 0.5, H = 2 for MS-COCO. For the SwAMP hyperparameters, we use SwAMP loss impact λ = 1.0, softmax temperature τ = 0.025, K = 1,000 classes and a queue size of 1,280 for both datasets. As shown, the SwAMP loss performs better than the contrastive loss on every reported metric except Text → Image R@10 on Flickr30K. SwAMP also outperforms several state-of-the-art methods, including the recent sophisticated probabilistic embedding strategy labelled PCME.
[0180] Compared with the computationally expensive SCAN, SwAMP outperforms SCAN in most configurations, falling behind only SCAN's best attention direction/combination choice (e.g., SCAN t-i + i-t in Table 7). Note that SwAMP uses a simple feature aggregation strategy (PAR) to obtain a fast and succinct modality-wise feature representation, whereas SCAN relies on a cross-modal attention similarity scoring model, which is computationally expensive.
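To make the cost difference concrete, the Python sketch below contrasts a modality-wise score with a simplified SCAN-like cross-attention score. It is an illustration only: the function names and the inverse temperature `inv_temp = 9.0` are hypothetical, the scorer omits details of the actual SCAN model, and PAR itself is not reproduced, only its end product of a single vector per instance.

```python
import math
from typing import Sequence

Vec = Sequence[float]


def dot(u: Vec, v: Vec) -> float:
    return sum(a * b for a, b in zip(u, v))


def cosine(u: Vec, v: Vec) -> float:
    nu, nv = math.sqrt(dot(u, u)), math.sqrt(dot(v, v))
    return dot(u, v) / (nu * nv) if nu and nv else 0.0


def modality_wise_score(image_emb: Vec, text_emb: Vec) -> float:
    """One d-dimensional dot product; embeddings can be precomputed and indexed."""
    return dot(image_emb, text_emb)


def cross_attention_score(regions: Sequence[Vec], words: Sequence[Vec],
                          inv_temp: float = 9.0) -> float:
    """Simplified SCAN-like image-to-text score with average pooling.

    Every region attends over every word, so one (image, text) pair costs
    len(regions) * len(words) similarity evaluations, and nothing can be
    reused across pairs.
    """
    region_scores = []
    for r in regions:
        weights = [math.exp(inv_temp * cosine(r, w)) for w in words]
        z = sum(weights)
        # Attended text context for this region: weighted average of word features.
        context = [sum(wt * w[d] for wt, w in zip(weights, words)) / z
                   for d in range(len(r))]
        region_scores.append(cosine(r, context))
    return sum(region_scores) / len(region_scores)
```

For N queries and a gallery of M items, the modality-wise route needs N + M encoder passes and N·M cheap dot products (which a nearest-neighbour index can serve), whereas the cross-attention scorer must be run on all N·M pairs over their local features.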
[0181] To illustrate the computational advantage of SwAMP-PAR, we compare the actual training and test times of the two approaches in Table 8, measured on the same machine with a single GPU (RTX 2080 Ti), a Core i7 3.50 GHz CPU, and 128 GB RAM. We report per-batch times for training and total retrieval times for testing. For the MS-COCO test, the running times for the 5 K test images are reported, with times for 1 K test images averaged over 5 folds shown in parentheses. For SCAN, using features in both directions (e.g., t-i AVG + i-t LSE) roughly doubles the running times. As shown, SwAMP-PAR is about 4 times faster than SCAN for training on both datasets, and the difference becomes even more pronounced at test time: SwAMP-PAR is about two orders of magnitude faster than the cross-modal attention model (roughly 89 times on Flickr and 361 times on the 5 K MS-COCO test set).
TABLE-US-00008 Running time comparison for SCAN (cross-modal attention) and our SwAMP-PAR

Method          Flickr            MS-COCO
                Train   Test      Train   Test
SCAN i-t AVG    0.35    336.9     0.33    9352.0 (350.3)
SwAMP-PAR       0.09    3.8               25.9 (16.3)
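The speed-up factors quoted in [0181] follow directly from Table 8, and this Python snippet recomputes them. Only cells present in Table 8 are used (no SwAMP-PAR MS-COCO training time is listed there), and the dictionary keys are hypothetical labels.

```python
# Times copied from Table 8 (per-batch for training, whole retrieval for test).
scan = {
    "Flickr train": 0.35,
    "Flickr test": 336.9,
    "MS-COCO test, 5K": 9352.0,
    "MS-COCO test, 1K (5-fold avg.)": 350.3,
}
swamp_par = {
    "Flickr train": 0.09,
    "Flickr test": 3.8,
    "MS-COCO test, 5K": 25.9,
    "MS-COCO test, 1K (5-fold avg.)": 16.3,
}


def speedups(baseline: dict, ours: dict) -> dict:
    """Ratio baseline_time / our_time for every setting present in both tables."""
    return {key: baseline[key] / ours[key] for key in ours if key in baseline}


for setting, ratio in speedups(scan, swamp_par).items():
    print(f"{setting}: {ratio:.1f}x faster")
```

This gives about 3.9x for Flickr training, about 89x (Flickr) and 361x (MS-COCO, 5K) at test time, and about 21x on the 1K MS-COCO folds.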
[0182] The results above show the improvements offered by the proposed clustering-based loss function for cross-modal retrieval. The swapped class assignment across modalities enables better feature alignment with greater flexibility, while also significantly reducing the sampling complexity. The efficacy of our approach is demonstrated on several real-world cross-modal retrieval problems in diverse modalities (text-video, sketch-photo, and image-text), where it achieves significant performance improvements over contrastive learning on all of these tasks.
[0183] Once the model is trained, it is output to a device for use.
[0184] The apparatus comprises the standard components, including for example at least one processor 102 coupled to memory 104. The at least one processor 102 may comprise one or more of: a microprocessor, a microcontroller, and an integrated circuit. The memory 104 may comprise volatile memory, such as random access memory (RAM), for use as temporary memory, and/or non-volatile memory such as Flash, read only memory (ROM), or electrically erasable programmable ROM (EEPROM), for storing data, programs, or instructions, for example.
[0185] The apparatus may further comprise at least one image capture device 108 for capturing images or videos to be processed by the ML model. The apparatus may further comprise at least one interface 110 for a user to input other data to be processed by the ML model, e.g. a text query or a sketch query. The at least one interface may also provide a result of the processing by the ML model to a user of the apparatus. For example, the apparatus 100 may comprise a display screen to receive user inputs and to display the results of implementing the ML model 106.
[0186] As demonstrated above, the SwAMP approach can be applied to different types of cross-modal retrieval problems. Moreover, as empirically demonstrated, the SwAMP loss improves retrieval performance significantly over contrastive learning on various real-world cross-modal retrieval problems, including text-video, sketch-image, and image-text retrieval.
[0187] The SwAMP approach has two main benefits: i) Since learning does not rely entirely on pair-based losses as in contrastive learning, the sampling complexity is reduced; this follows from the class-based loss adopted in SwAMP. ii) Unlike the contrastive loss, SwAMP does not make the potentially wrong assumption that instances from different pairs are automatically irrelevant. The optimized class assignment finds similar instances from other pairs, and the feature extractor is trained so that same-class instances, even in different pairs, are well aligned. The contrastive loss hardly exploits this alignment of instances across different pairs.
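Benefit ii) can be illustrated with a small Python sketch that compares which cross-modal pairs each view treats as relevant. The class labels are hypothetical stand-ins for the optimized class assignment, and the helper names are likewise illustrative.

```python
from typing import List, Sequence, Tuple


def contrastive_targets(batch_size: int) -> List[List[int]]:
    """Pair-based view: only the paired item (the diagonal) counts as relevant."""
    return [[int(i == j) for j in range(batch_size)] for i in range(batch_size)]


def class_based_targets(labels: Sequence[int]) -> List[List[int]]:
    """Class-based view: items are relevant whenever their assigned classes match."""
    return [[int(a == b) for b in labels] for a in labels]


def pushed_apart_by_contrastive(labels: Sequence[int]) -> List[Tuple[int, int]]:
    """(query, item) pairs of the same class that a contrastive loss treats as negatives."""
    n = len(labels)
    pair_view = contrastive_targets(n)
    class_view = class_based_targets(labels)
    return [(i, j) for i in range(n) for j in range(n)
            if class_view[i][j] and not pair_view[i][j]]


# Hypothetical class labels for a batch of four (image, text) pairs.
labels = [3, 7, 3, 1]
print(pushed_apart_by_contrastive(labels))  # [(0, 2), (2, 0)]
```

Pairs 0 and 2 share a class, yet the pair-based view would push them apart. The sketch also hints at benefit i): the pair-based targets grow as B×B for a batch of B pairs, whereas a class-based loss scores each instance against K classes.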
[0188] As discussed previously, there are broadly two ways to define the similarity metric between instances of different modalities: modality-wise feature representation and cross-modal attention. Examples of the cross-modal attention approach are described in “Stacked Cross Attention for Image-Text Matching” by Lee et al. published in European Conference on Computer Vision in 2018, “VirTex: Learning Visual Representations from Textual Annotations” by Desai et al. published as arXiv preprint arXiv:2006.06666 in 2020, “Pixel-BERT: Aligning image pixels with text by deep multi-modal transformers” by Huang et al. published as arXiv preprint arXiv:2004.00849 in 2020, and “ViLBERT: Pretraining task-agnostic visiolinguistic representations for vision-and-language tasks” by Lu et al. published in Advances in Neural Information Processing Systems in 2019. The main benefit of modality-wise feature representation is computational efficiency: thanks to the efficient dot product, it scales to billions of instances at training and test time. Cross-modal attention, by contrast, computes the similarity score directly, without a modality-wise representation, using transformer-like attentive neural networks. Although such models can capture cross-modal interactions between local features of data instances from different modalities, they are computationally demanding and very slow due to their quadratic complexity in the number of local features. “Thinking Fast and Slow: Efficient Text-to-Visual Retrieval With Transformers” by Miech et al. published in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 9826-9836 in 2021 describes a hybrid of the two approaches that retains both models but performs re-ranking/distillation at test time for speed-up.
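The hybrid strategy mentioned above can be sketched generically in Python: a cheap dot-product stage shortlists candidates and an expensive scorer re-ranks only the shortlist. This illustrates the general fast-then-slow idea under simplifying assumptions and is not Miech et al.'s model; `slow_score` is a hypothetical stand-in for a cross-attention scorer, and distillation is not shown.

```python
import heapq
from typing import Callable, List, Sequence


def retrieve_then_rerank(
    query_emb: Sequence[float],
    gallery_embs: Sequence[Sequence[float]],
    shortlist_size: int,
    slow_score: Callable[[int], float],
) -> List[int]:
    """Return gallery indices: dot-product shortlist, re-ranked by an expensive scorer."""
    # Stage 1 (fast): one dot product per gallery item using precomputed embeddings.
    fast_scores = [
        (sum(q * g for q, g in zip(query_emb, emb)), idx)
        for idx, emb in enumerate(gallery_embs)
    ]
    shortlist = [idx for _, idx in heapq.nlargest(shortlist_size, fast_scores)]
    # Stage 2 (slow): the expensive scorer runs only on the shortlist.
    return sorted(shortlist, key=slow_score, reverse=True)


calls = []


def slow(idx: int) -> float:
    """Hypothetical expensive scorer; records each call to show how rarely it runs."""
    calls.append(idx)
    return {0: 0.2, 2: 0.7}.get(idx, 0.0)


gallery = [[0.9, 0.0], [0.1, 0.0], [0.8, 0.0], [0.0, 1.0]]
ranking = retrieve_then_rerank([1.0, 0.0], gallery, 2, slow)
print(ranking, len(calls))  # [2, 0] 2
```

The expensive scorer is called only `shortlist_size` times per query instead of once per gallery item, which is where the speed-up comes from.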
[0189] Clustering-based approaches. There have been previous attempts to cluster (group) data instances, or equivalently to self-label them, to improve saliency in representation learning. Some approaches, such as those described in “Deep Clustering for Unsupervised Learning of Visual Features” by Caron et al. published in European Conference on Computer Vision in 2018 or “Self-Supervised Learning by Cross-Modal Audio-Video Clustering” by Alwassel et al. published in Advances in Neural Information Processing Systems in 2020, perform offline K-means clustering at every epoch, which can make training slow. The idea of optimizing class labels in representation learning was previously introduced in “Self-labelling via simultaneous clustering and representation learning” by Asano et al. published in International Conference on Learning Representations in 2020 and “Unsupervised Learning of Visual Features by Contrasting Cluster Assignments” by Caron et al. published in Advances in Neural Information Processing Systems in 2020. However, all of these previous approaches targeted self-supervised representation learning, using instance discrimination with augmented data as the pretext task. In contrast, we simultaneously learn the class labels and the feature extraction networks in the cross-modal retrieval setting.
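In contrast to per-epoch offline K-means, the Sinkhorn-Knopp iteration used by Asano et al. and by Caron et al. (SwAV) computes an equipartitioned soft class assignment on the fly. The Python sketch below is a minimal version, assuming ε = 0.05 and three iterations (values commonly used with SwAV). Whether SwAMP optimizes its class assignment in exactly this way is not stated in this section, so treat it as an illustration of the technique only; the function name is hypothetical.

```python
import math
from typing import List, Sequence


def sinkhorn_assign(
    scores: Sequence[Sequence[float]],
    epsilon: float = 0.05,
    n_iters: int = 3,
) -> List[List[float]]:
    """Balanced soft assignment of N instances to K classes (Sinkhorn-Knopp).

    scores[i][k] is the similarity of instance i to class prototype k. Each
    returned row is a distribution over the K classes, and the iterations push
    every class towards receiving about N / K instances in total.
    """
    n, k = len(scores), len(scores[0])
    m = max(max(row) for row in scores)  # shift before exp for numerical stability
    q = [[math.exp((s - m) / epsilon) for s in row] for row in scores]
    for _ in range(n_iters):
        # Column step: give each class a total mass of 1 / K.
        col_sums = [sum(q[i][j] for i in range(n)) for j in range(k)]
        q = [[q[i][j] / (col_sums[j] * k) for j in range(k)] for i in range(n)]
        # Row step: give each instance a total mass of 1 / N.
        row_sums = [sum(row) for row in q]
        q = [[v / (row_sums[i] * n) for v in q[i]] for i in range(n)]
    # Rescale so that every row sums to 1.
    return [[v * n for v in row] for row in q]


scores = [[0.9, 0.1], [0.8, 0.2], [0.6, 0.4], [0.55, 0.45]]
balanced = sinkhorn_assign(scores)
labels = [max(range(2), key=lambda j: row[j]) for row in balanced]
print(labels)  # [0, 0, 1, 1]
```

A plain argmax on `scores` would put all four instances in class 0; the equipartition constraint instead moves the two least confident instances to class 1, which prevents the degenerate solution in which every instance shares a single label.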
[0190] Proxy-based loss. Beyond the pair-based contrastive loss, a loss function based on the semantic class labels of data instances, known as a proxy-based loss, can also be used. As explained above, proxy-based methods introduce learnable proxy vectors (class representatives), one per class, and pull the data instances of each class toward the proxy of that class. Because the loss function is defined solely by distances between data instances and proxy vectors, without pairwise distances between instances, it reduces the sampling complexity to linear. The idea has been introduced in deep metric learning approaches such as proxy-NCA (described in “No Fuss Distance Metric Learning using Proxies” by Movshovitz-Attias et al. published in International Conference on Computer Vision in 2017), SoftTriple (described in “SoftTriple Loss: Deep Metric Learning Without Triplet Sampling” by Qian et al. published in Proceedings of the IEEE International Conference on Computer Vision in 2019), and proxy-anchor (described in “Proxy Anchor Loss for Deep Metric Learning” by Kim et al. published in IEEE/CVF Conference on Computer Vision and Pattern Recognition in 2020). However, unlike the new method described above (SwAMP), they address the supervised setup, in which ground-truth semantic class labels are provided.
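To illustrate the linear cost, here is a minimal Python sketch of a Proxy-NCA-style objective. It assumes L2-normalized embeddings and proxies, squared Euclidean distance, and the positive proxy excluded from the denominator, which is our reading of the Proxy-NCA formulation (details may differ from the paper). The proxies and labels are given as inputs here; in the supervised setting the labels are ground truth, and in a real model the proxies would be learned. Function names are hypothetical.

```python
import math
from typing import List, Sequence


def l2_normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]


def sq_dist(u: Sequence[float], v: Sequence[float]) -> float:
    return sum((a - b) ** 2 for a, b in zip(u, v))


def proxy_nca_loss(
    embeddings: Sequence[Sequence[float]],
    labels: Sequence[int],
    proxies: Sequence[Sequence[float]],
) -> float:
    """Mean of -log(exp(-d(x, p_y)) / sum_{z != y} exp(-d(x, p_z))).

    Each instance is compared with the K proxies only, so one pass costs N * K
    distance evaluations instead of the N * N of a pair-based loss.
    """
    unit_proxies = [l2_normalize(p) for p in proxies]
    total = 0.0
    for x, y in zip(embeddings, labels):
        x = l2_normalize(x)
        neg_d = [-sq_dist(x, p) for p in unit_proxies]
        others = [neg_d[z] for z in range(len(unit_proxies)) if z != y]
        m = max(others)
        log_denominator = m + math.log(sum(math.exp(v - m) for v in others))
        total += log_denominator - neg_d[y]
    return total / len(embeddings)


E = [[1.0, 0.0], [0.0, 1.0]]
P = [[1.0, 0.0], [0.0, 1.0]]
print(proxy_nca_loss(E, [0, 1], P), proxy_nca_loss(E, [1, 0], P))
```

Because the positive proxy is excluded from the denominator, this loss can become negative when every instance sits on its own proxy; with the labels swapped, the loss rises accordingly.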
[0191] Those skilled in the art will appreciate that while the foregoing has described what is considered to be the best mode and, where appropriate, other modes of performing the present techniques, the present techniques should not be limited to the specific configurations and methods disclosed in this description of the preferred embodiment. Those skilled in the art will recognise that the present techniques have a broad range of applications, and that the embodiments may take a wide range of modifications without departing from any inventive concept as defined in the appended claims.