-
Notifications
You must be signed in to change notification settings - Fork 107
Sinkhorn loss implementatin #809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: 0.3
Are you sure you want to change the base?
Changes from all commits
e5f50f4
9bcaef4
2ad35c1
da65e74
f444a10
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| Sinkhorn Loss | ||
| =============== | ||
| .. currentmodule:: pina.loss.sinkhorn_loss | ||
|
|
||
| .. automodule:: pina._src.loss.sinkhorn_loss | ||
| :no-members: | ||
|
|
||
| .. autoclass:: pina._src.loss.sinkhorn_loss.SinkhornLoss | ||
| :members: | ||
| :show-inheritance: |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| """Module for the SinkhornLoss class.""" | ||
|
|
||
| import torch | ||
| from pina._src.loss.base_dual_loss import BaseDualLoss | ||
| from pina._src.core.utils import check_consistency, check_positive_integer | ||
|
|
||
|
|
||
| class SinkhornLoss(BaseDualLoss): | ||
| r""" | ||
| Implementation of the Sinkhorn Loss based on regularized optimal transport. | ||
| It measures the regularized Wasserstein distance between the empirical | ||
| distributions represented by ``input`` (with :math:`N` samples) and | ||
| ``target`` (with :math:`M` samples), each in :math:`\mathbb{R}^D`. | ||
|
|
||
| The loss solves the entropy-regularized optimal transport problem: | ||
|
|
||
| .. math:: | ||
| W_\varepsilon(\mu, \nu) = \min_{\pi \in \Pi(\mu, \nu)} | ||
| \langle C, \pi \rangle - \varepsilon H(\pi), | ||
|
|
||
| where :math:`C_{ij} = \|x_i - y_j\|_2^p` is the cost matrix, | ||
| :math:`H(\pi) = -\sum_{ij} \pi_{ij} \log \pi_{ij}` is the entropy of | ||
| the transport plan, and :math:`\varepsilon > 0` is the regularization | ||
| strength. The dual objective recovered by the Sinkhorn iterations is: | ||
|
|
||
| .. math:: | ||
| W_\varepsilon = \langle a, f^* \rangle + \langle b, g^* \rangle, | ||
|
|
||
| where :math:`a` and :math:`b` are uniform probability weights over the | ||
| :math:`N` and :math:`M` samples respectively, and :math:`f^*, g^*` are | ||
| the optimal dual potentials computed via log-space Sinkhorn iterations. | ||
|
|
||
|
|
||
| .. note:: | ||
|
guglielmopadula marked this conversation as resolved.
|
||
| Unlike pointwise losses, the Sinkhorn loss operates on entire empirical | ||
| distributions, so the output is always a scalar regardless of the | ||
| number of samples. | ||
|
|
||
| .. note:: | ||
| Smaller values of ``eps`` approximate the true Wasserstein distance | ||
| more closely but may require more iterations to converge. | ||
|
|
||
| .. seealso:: | ||
|
|
||
| **Original reference:** Patrini, G., Berg, R., Forre, P., Carlin, M., | ||
| and Bhatt, W. *Sinkhorn AutoEncoders*. In Proceedings of the Thirty-Fifth | ||
| Conference on Uncertainty in Artificial Intelligence (UAI), 2019. | ||
| `arXiv:1810.01118 <https://arxiv.org/abs/1810.01118>`_. | ||
| """ | ||
|
|
||
| def __init__(self, p=2, eps=0.1, max_iter=100): | ||
| """ | ||
| Initialization of the :class:`SinkhornLoss` class. | ||
|
|
||
| :param int p: Exponent of the cost function. Default is ``2``. | ||
| :param eps: Entropy regularization strength | ||
| :type eps: int | float | ||
| :math:`\varepsilon > 0`. Default is ``0.1``. | ||
| :param int max_iter: Number of Sinkhorn iterations. Default is ``100``. | ||
| :raises ValueError: If ``p`` is not a numeric value. | ||
| :raises ValueError: If ``eps`` is not a positive number. | ||
| :raises AssertionError: If ``max_iter`` is not a strictly positive int. | ||
| """ | ||
| super().__init__(reduction="mean") | ||
|
|
||
| check_consistency(p, (int, float)) | ||
| check_consistency(eps, (int, float)) | ||
| if eps <= 0: | ||
| raise ValueError( | ||
| f"eps must be a strictly positive float, got {eps}." | ||
| ) | ||
| check_positive_integer(max_iter, strict=True) | ||
|
|
||
| self.p = p | ||
| self.eps = eps | ||
| self.max_iter = max_iter | ||
|
|
||
| def forward(self, input, target): | ||
| """ | ||
| Forward method of the loss function. | ||
|
|
||
| :param torch.Tensor input: Input tensor of shape :math:`(N, D)`. | ||
| :param torch.Tensor target: Target tensor of shape :math:`(M, D)`. | ||
| :return: Sinkhorn loss value. | ||
| :rtype: torch.Tensor | ||
| """ | ||
| n = input.shape[0] | ||
| m = target.shape[0] | ||
|
|
||
| a = input.new_full((n,), 1.0 / n) | ||
| b = target.new_full((m,), 1.0 / m) | ||
| f = torch.zeros(n, dtype=input.dtype, device=input.device) | ||
| g = torch.zeros(m, dtype=target.dtype, device=target.device) | ||
|
|
||
|
|
||
| # Cost matrix C[i,j] = ||x_i - y_j||_2^p, shape (N, M) | ||
| diff = input.unsqueeze(1) - target.unsqueeze(0) # (N, M, D) | ||
| C = torch.linalg.norm(diff, ord=2, dim=-1).pow(self.p) # (N, M) | ||
|
|
||
| # Log-space Sinkhorn iterations for numerical stability | ||
| log_a = a.log() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would keep a,b snd not pass the log space (or if we pass the log space maybe add a bolean Flag log_space_opt)?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I understand correctly, working in log space allows us to use
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The variables a and b are used in the currently line 124. |
||
| log_b = b.log() | ||
|
|
||
|
|
||
| for _ in range(self.max_iter): | ||
| # f_i = eps * (log a_i - logsumexp_j ((g_j - C_ij) / eps)) | ||
| f = self.eps * ( | ||
| log_a - torch.logsumexp((g.unsqueeze(0) - C) / self.eps, dim=1) | ||
| ) | ||
| # g_j = eps * (log b_j - logsumexp_i ((f_i - C_ij) / eps)) | ||
| g = self.eps * ( | ||
| log_b - torch.logsumexp((f.unsqueeze(1) - C) / self.eps, dim=0) | ||
| ) | ||
|
|
||
| loss = (a * f).sum() + (b * g).sum() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I understand correctly, this loss can take negative values. Please confirm whether this is expected behavior; otherwise, adjust the implementation accordingly.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The loss can indeed take negative values. |
||
| return self._reduction(loss.unsqueeze(0)) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| import torch | ||
| import pytest | ||
|
|
||
| from pina.loss import SinkhornLoss | ||
|
|
||
| # Fixed random tensors for reproducibility | ||
| torch.manual_seed(0) | ||
| input_ = torch.rand(10, 2) | ||
| target_ = torch.rand(8, 2) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("p", [1, 2]) | ||
| @pytest.mark.parametrize("eps", [0.01, 0.1]) | ||
| @pytest.mark.parametrize("max_iter", [10, 100]) | ||
| def test_constructor(p, eps, max_iter): | ||
|
|
||
| SinkhornLoss(p=p, eps=eps, max_iter=max_iter) | ||
|
|
||
| # Should fail if p is not numeric | ||
| with pytest.raises(ValueError): | ||
| SinkhornLoss( | ||
| p="invalid", eps=eps, max_iter=max_iter | ||
| ) | ||
|
|
||
| # Should fail if eps is not numeric | ||
| with pytest.raises(ValueError): | ||
| SinkhornLoss(p=p, eps="bad", max_iter=max_iter) | ||
|
|
||
|
|
||
| # Should fail if eps is not positive | ||
| with pytest.raises(ValueError): | ||
| SinkhornLoss(p=p, eps=-0.1, max_iter=max_iter) | ||
|
|
||
| # Should fail if max_iter is not a positive integer | ||
| with pytest.raises(AssertionError): | ||
| SinkhornLoss(p=p, eps=eps, max_iter=0) | ||
|
|
||
|
|
||
|
|
||
| @pytest.mark.parametrize("p", [1, 2, 3]) | ||
| @pytest.mark.parametrize("eps", [0.01, 0.1, 1.0, 1]) | ||
| @pytest.mark.parametrize("max_iter", [10, 100]) | ||
| @pytest.mark.parametrize("input, target", [ | ||
| (torch.rand(10, 2), torch.rand(8, 2)), # different N, M; 2D features | ||
| (torch.rand(5, 3), torch.rand(5, 3)), # same N=M; 3D features | ||
| (torch.rand(1, 4), torch.rand(7, 4)), # single input sample | ||
| (torch.rand(6, 4), torch.rand(1, 4)), # single target sample | ||
| (torch.rand(3, 1), torch.rand(4, 1)), # 1D feature space | ||
| (input_, input_), # identical distributions | ||
| (input_, target_), # different distributions | ||
| ]) | ||
| def test_forward(p, eps, max_iter, input, target): | ||
| # Output must always be a scalar and finite. The (non-debiased) Sinkhorn | ||
| # dual can be negative due to the entropy regularization term; this is | ||
| # expected behavior. eps can be an integer as well as a float. | ||
| loss_fn = SinkhornLoss(p=p, eps=eps, max_iter=max_iter) | ||
| value = loss_fn(input, target) | ||
| assert value.shape == torch.Size([1]) | ||
| assert torch.isfinite(value).all() | ||
| # Sinkhorn loss on identical data should be smaller than on different data. | ||
| if input is input_ and target is target_: | ||
| loss_same = SinkhornLoss(p=p, eps=eps, max_iter=500)(input_, input_) | ||
| assert loss_same.item() < value.item() | ||
|
|
||
|
|
||
| def test_forward_approaches_wasserstein(): | ||
|
|
||
| # For 1-D sorted distributions, W_2^2 = sum |x_i - y_i|^2 / N | ||
| x = torch.tensor([[1.0], [2.0], [3.0]]) | ||
| y = torch.tensor([[4.0], [5.0], [6.0]]) | ||
| # W_2^2 = ((1-4)^2 + (2-5)^2 + (3-6)^2) / 3 = 9 | ||
| value = SinkhornLoss(p=2, eps=1e-3, max_iter=5000)(x, y) | ||
| assert abs(value.item() - 9.0) < 0.1 |
Uh oh!
There was an error while loading. Please reload this page.