Question regarding the training objective (loss function). #4

Open
EryiXie opened this issue Sep 4, 2024 · 1 comment

EryiXie commented Sep 4, 2024

Hi, thank you for this excellent work. I am already using it in my project, where it shows impressive results.

I just want to ask about the training objective:

Qinco/utils.py

Lines 28 to 52 in 0ddfc77

def pairwise_distances(a, b):
    """
    a (torch.Tensor): Shape [na, d]
    b (torch.Tensor): Shape [nb, d]
    Returns (torch.Tensor): Shape [na,nb]
    """
    anorms = (a**2).sum(-1)
    bnorms = (b**2).sum(-1)
    return anorms[:, None] + bnorms - 2 * a @ b.T


def compute_batch_distances(a, b):
    """
    a (torch.Tensor): Shape [n, a, d]
    b (torch.Tensor): Shape [n, b, d]
    Returns (torch.Tensor): Shape [n,a,b]
    """
    anorms = (a**2).sum(-1)
    bnorms = (b**2).sum(-1)
    # return anorms.unsqueeze(-1) + bnorms.unsqueeze(1) - 2 * torch.einsum('nad,nbd->nab',a,b)
    return (
        anorms.unsqueeze(-1) + bnorms.unsqueeze(1) - 2 * torch.bmm(a, b.transpose(2, 1))
    )

Here, the (squared) L2 distance is used in both functions. I am curious about this choice of L2 distance over cosine similarity. Is there any insight behind using L2 distance here?

Best regards.

Contributor

mdouze commented Sep 4, 2024

First, cosine similarity and L2 distance are equivalent for normalized vectors, since ||a - b||^2 = 2 - 2 cos(a, b) when both vectors have unit norm (see https://github.com/facebookresearch/faiss/wiki/MetricType-and-distances#how-can-i-index-vectors-for-cosine-similarity). So if queries and database vectors are both normalized, QINCo will work just as well for cosine as for L2 search.
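As a quick sanity check of this equivalence, here is a minimal sketch in plain Python (standard library only, not taken from the QINCo code base). The helper names `normalize`, `sq_l2` and `dot` are made up for this illustration. It checks that, for unit-norm vectors, the squared L2 distance equals 2 minus twice the inner product, so both metrics pick the same nearest neighbor.

```python
import math
import random


def normalize(v):
    """Scale a vector (list of floats) to unit L2 norm."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def sq_l2(a, b):
    """Squared L2 distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def dot(a, b):
    """Inner product; equals cosine similarity for unit-norm inputs."""
    return sum(x * y for x, y in zip(a, b))


random.seed(0)
d = 8
query = normalize([random.gauss(0, 1) for _ in range(d)])
database = [normalize([random.gauss(0, 1) for _ in range(d)]) for _ in range(20)]

# For unit vectors: ||q - x||^2 = 2 - 2 * <q, x>
max_gap = max(abs(sq_l2(query, x) - (2 - 2 * dot(query, x))) for x in database)

# Smallest L2 distance and largest cosine similarity select the same item.
nearest_l2 = min(range(len(database)), key=lambda i: sq_l2(query, database[i]))
nearest_cos = max(range(len(database)), key=lambda i: dot(query, database[i]))
```

Here `max_gap` should be zero up to floating-point rounding, and `nearest_l2` should equal `nearest_cos`.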

Regarding this function: although the input vectors are normalized, the centroids are not. Thus, replacing this part of the loss with cosine similarity would yield different results, and it's hard to tell what the impact would be.
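To illustrate why the two metrics can disagree once the centroids are not normalized, here is a small hand-built sketch in plain Python. The vectors and helper names are hypothetical and chosen only for this example. The squared distance uses the same expansion as `pairwise_distances` above (||a||^2 + ||b||^2 - 2 a.b).

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def sq_l2_expanded(a, b):
    """Squared L2 distance via ||a||^2 + ||b||^2 - 2 <a, b>."""
    return dot(a, a) + dot(b, b) - 2 * dot(a, b)


def cosine(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


x = [1.0, 0.0]        # unit-norm input vector
centroids = [
    [0.9, 0.3],       # close to x, slightly off-direction
    [3.0, 0.0],       # same direction as x, but far away (norm 3)
]

l2_choice = min(range(len(centroids)), key=lambda i: sq_l2_expanded(x, centroids[i]))
cos_choice = max(range(len(centroids)), key=lambda i: cosine(x, centroids[i]))
```

In this toy case the L2 criterion selects the first centroid (squared distance about 0.1 versus 4), while cosine similarity selects the second (similarity 1 versus about 0.95), so the assignments, and hence the loss, would differ.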

Intuitively, the centroids should not be normalized because their magnitude decreases with each QINCo step.
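The intuition can be sketched with plain residual quantization on a scalar, which is a simplification and not QINCo itself (QINCo adapts its codebooks with a neural network). The codebooks below are hand-picked, hypothetical values: each step quantizes what the previous steps left over, so useful centroids shrink along with the residual.

```python
# Hand-picked toy codebooks (hypothetical), one per quantization step.
codebooks = [
    [-0.8, 0.8],
    [-0.15, 0.15],
    [-0.04, 0.04],
]

x = 1.0          # a unit-norm "vector" in one dimension
residual = x
selected = []
for codebook in codebooks:
    # Pick the centroid closest to the current residual (squared L2).
    best = min(codebook, key=lambda cand: (residual - cand) ** 2)
    selected.append(best)
    residual -= best

# Magnitudes of the selected centroids, step by step.
magnitudes = [abs(c) for c in selected]
```

The selected centroids have magnitudes 0.8, 0.15 and 0.04. Forcing them all to unit norm would leave the later steps unable to represent the small remaining residual.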
