Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,8 @@ Losses
BaseDualLoss <loss/base_dual_loss.rst>
LpLoss <loss/lp_loss.rst>
PowerLoss <loss/power_loss.rst>
SinkhornLoss <loss/sinkhorn_loss.rst>


Weighting Schemas
--------------------
Expand All @@ -343,4 +345,4 @@ Weighting Schemas
Neural-Tangent-Kernel Weighting <weighting/ntk_weighting.rst>
No Weighting <weighting/no_weighting.rst>
Scalar Weighting <weighting/scalar_weighting.rst>
Self-Adaptive Weighting <weighting/self_adaptive_weighting.rst>
Self-Adaptive Weighting <weighting/self_adaptive_weighting.rst>
10 changes: 10 additions & 0 deletions docs/source/_rst/loss/sinkhorn_loss.rst
Comment thread
guglielmopadula marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Sinkhorn Loss
===============
.. currentmodule:: pina.loss.sinkhorn_loss

.. automodule:: pina._src.loss.sinkhorn_loss
:no-members:

.. autoclass:: pina._src.loss.sinkhorn_loss.SinkhornLoss
:members:
:show-inheritance:
116 changes: 116 additions & 0 deletions pina/_src/loss/sinkhorn_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Module for the SinkhornLoss class."""

import torch
from pina._src.loss.base_dual_loss import BaseDualLoss
from pina._src.core.utils import check_consistency, check_positive_integer


class SinkhornLoss(BaseDualLoss):
r"""
Implementation of the Sinkhorn Loss based on regularized optimal transport.
It measures the regularized Wasserstein distance between the empirical
distributions represented by ``input`` (with :math:`N` samples) and
``target`` (with :math:`M` samples), each in :math:`\mathbb{R}^D`.

The loss solves the entropy-regularized optimal transport problem:

.. math::
W_\varepsilon(\mu, \nu) = \min_{\pi \in \Pi(\mu, \nu)}
\langle C, \pi \rangle - \varepsilon H(\pi),

where :math:`C_{ij} = \|x_i - y_j\|_2^p` is the cost matrix,
:math:`H(\pi) = -\sum_{ij} \pi_{ij} \log \pi_{ij}` is the entropy of
the transport plan, and :math:`\varepsilon > 0` is the regularization
strength. The dual objective recovered by the Sinkhorn iterations is:

.. math::
W_\varepsilon = \langle a, f^* \rangle + \langle b, g^* \rangle,

where :math:`a` and :math:`b` are uniform probability weights over the
:math:`N` and :math:`M` samples respectively, and :math:`f^*, g^*` are
the optimal dual potentials computed via log-space Sinkhorn iterations.


.. note::
Comment thread
guglielmopadula marked this conversation as resolved.
Unlike pointwise losses, the Sinkhorn loss operates on entire empirical
distributions, so the output is always a scalar regardless of the
number of samples.

.. note::
Smaller values of ``eps`` approximate the true Wasserstein distance
more closely but may require more iterations to converge.

.. seealso::

**Original reference:** Patrini, G., van den Berg, R., Forre, P.,
Carioni, M., Bhargav, S., Welling, M., Genewein, T., and Nielsen, F.
*Sinkhorn AutoEncoders*. In Proceedings of the Thirty-Fifth
Conference on Uncertainty in Artificial Intelligence (UAI), 2019.
`arXiv:1810.01118 <https://arxiv.org/abs/1810.01118>`_.
"""

def __init__(self, p=2, eps=0.1, max_iter=100):
    r"""
    Initialization of the :class:`SinkhornLoss` class.

    :param p: Exponent of the cost function. Default is ``2``.
    :type p: int | float
    :param eps: Entropy regularization strength :math:`\varepsilon > 0`.
        Default is ``0.1``.
    :type eps: int | float
    :param int max_iter: Number of Sinkhorn iterations. Default is ``100``.
    :raises ValueError: If ``p`` is not a numeric value.
    :raises ValueError: If ``eps`` is not a numeric value or is not
        strictly positive.
    :raises AssertionError: If ``max_iter`` is not a strictly positive int.
    """
    # The docstring above is a raw string on purpose: in a regular string
    # the ``\v`` of ``\varepsilon`` would be read as a vertical-tab escape.
    super().__init__(reduction="mean")

    check_consistency(p, (int, float))
    check_consistency(eps, (int, float))
    # ``not eps > 0`` instead of ``eps <= 0`` so that NaN, for which every
    # comparison is False, is rejected as well.
    if not eps > 0:
        raise ValueError(
            f"eps must be a strictly positive float, got {eps}."
        )
    check_positive_integer(max_iter, strict=True)

    self.p = p
    self.eps = eps
    self.max_iter = max_iter

def forward(self, input, target):
"""
Forward method of the loss function.

:param torch.Tensor input: Input tensor of shape :math:`(N, D)`.
:param torch.Tensor target: Target tensor of shape :math:`(M, D)`.
:return: Sinkhorn loss value.
:rtype: torch.Tensor
"""
# Sample counts: N rows of ``input`` and M rows of ``target``.
n = input.shape[0]
m = target.shape[0]

# Uniform marginals: every sample carries mass 1/N (resp. 1/M).
a = input.new_full((n,), 1.0 / n)
b = target.new_full((m,), 1.0 / m)
# Dual potentials, initialised to zero.
f = torch.zeros(n, dtype=input.dtype, device=input.device)
g = torch.zeros(m, dtype=target.dtype, device=target.device)


# Cost matrix C[i,j] = ||x_i - y_j||_2^p, shape (N, M)
diff = input.unsqueeze(1) - target.unsqueeze(0) # (N, M, D)
C = torch.linalg.norm(diff, ord=2, dim=-1).pow(self.p) # (N, M)

# Log-space Sinkhorn iterations for numerical stability
log_a = a.log()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like a and b are never used. Please, consider defining directly the log instances

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep a,b snd not pass the log space (or if we pass the log space maybe add a bolean Flag log_space_opt)?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, working in log space allows us to use logsumexp, which is generally more numerically stable. I would avoid exposing this as a flag: the two approaches are mathematically equivalent, and the difference is purely implementation-related.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variables a and b are used in the currently line 124.

log_b = b.log()


# Alternating updates, run for exactly ``max_iter`` rounds (there is no
# convergence test); each g-update uses the f computed in the same round.
for _ in range(self.max_iter):
# f_i = eps * (log a_i - logsumexp_j ((g_j - C_ij) / eps))
f = self.eps * (
log_a - torch.logsumexp((g.unsqueeze(0) - C) / self.eps, dim=1)
)
# g_j = eps * (log b_j - logsumexp_i ((f_i - C_ij) / eps))
g = self.eps * (
log_b - torch.logsumexp((f.unsqueeze(1) - C) / self.eps, dim=0)
)

# Dual objective <a, f> + <b, g>, evaluated with the potentials from the
# last round. The value may be negative (confirmed in the review thread
# that follows).
loss = (a * f).sum() + (b * g).sum()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, this loss can take negative values. Please confirm whether this is expected behavior; otherwise, adjust the implementation accordingly.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The loss can indeed take negative values.

# ``loss`` is 0-dimensional; it gets a leading axis before the reduction.
# NOTE(review): ``_reduction`` is not defined in this class -- presumably
# inherited from ``BaseDualLoss``; confirm.
return self._reduction(loss.unsqueeze(0))
2 changes: 2 additions & 0 deletions pina/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
"BaseDualLoss",
"LpLoss",
"PowerLoss",
"SinkhornLoss",
]

from pina._src.loss.dual_loss_interface import DualLossInterface
from pina._src.loss.base_dual_loss import BaseDualLoss
from pina._src.loss.power_loss import PowerLoss
from pina._src.loss.lp_loss import LpLoss
from pina._src.loss.sinkhorn_loss import SinkhornLoss

# Back-compatibility with version 0.2, to be removed soon
import warnings
Expand Down
73 changes: 73 additions & 0 deletions tests/test_loss/test_sinkhorn_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import torch
import pytest

from pina.loss import SinkhornLoss

# Fixed random tensors for reproducibility. They are created once at import
# time, right after seeding, and shared by the tests below; ``test_forward``
# recognises them by identity (``is``), so they must stay module-level objects.
torch.manual_seed(0)
input_ = torch.rand(10, 2)
target_ = torch.rand(8, 2)


@pytest.mark.parametrize("p", [1, 2])
@pytest.mark.parametrize("eps", [0.01, 0.1])
@pytest.mark.parametrize("max_iter", [10, 100])
def test_constructor(p, eps, max_iter):
    """Valid arguments build a loss; each invalid argument is rejected."""
    valid = {"p": p, "eps": eps, "max_iter": max_iter}

    # A fully valid configuration must build without raising.
    SinkhornLoss(**valid)

    # Each entry replaces one valid argument with a bad value and names the
    # exception the constructor is expected to raise for it.
    rejected = (
        ({"p": "invalid"}, ValueError),  # p is not numeric
        ({"eps": "bad"}, ValueError),  # eps is not numeric
        ({"eps": -0.1}, ValueError),  # eps is not positive
        ({"max_iter": 0}, AssertionError),  # max_iter is not a positive int
    )
    for override, expected_error in rejected:
        with pytest.raises(expected_error):
            SinkhornLoss(**{**valid, **override})



@pytest.mark.parametrize("p", [1, 2, 3])
@pytest.mark.parametrize("eps", [0.01, 0.1, 1.0, 1])
@pytest.mark.parametrize("max_iter", [10, 100])
@pytest.mark.parametrize(
    "input, target",
    [
        (torch.rand(10, 2), torch.rand(8, 2)),  # different N, M; 2D features
        (torch.rand(5, 3), torch.rand(5, 3)),  # same N=M; 3D features
        (torch.rand(1, 4), torch.rand(7, 4)),  # single input sample
        (torch.rand(6, 4), torch.rand(1, 4)),  # single target sample
        (torch.rand(3, 1), torch.rand(4, 1)),  # 1D feature space
        (input_, input_),  # identical distributions
        (input_, target_),  # different distributions
    ],
)
def test_forward(p, eps, max_iter, input, target):
    """
    The loss must come back finite and with shape ``(1,)``.

    The (non-debiased) Sinkhorn dual can be negative because of the entropy
    regularization term; that is expected, so no sign is asserted. ``eps`` is
    exercised both as a float and as an integer.
    """
    criterion = SinkhornLoss(p=p, eps=eps, max_iter=max_iter)
    result = criterion(input, target)
    assert result.shape == torch.Size([1])
    assert torch.isfinite(result).all()

    # The ordering check only applies to the shared (input_, target_) pair.
    if input is not input_ or target is not target_:
        return

    # Sinkhorn loss on identical data should be smaller than on different data.
    baseline = SinkhornLoss(p=p, eps=eps, max_iter=500)(input_, input_)
    assert baseline.item() < result.item()


def test_forward_approaches_wasserstein():
    """With a small ``eps`` the loss should be close to the exact W_2^2."""
    # For 1-D sorted distributions, W_2^2 = sum |x_i - y_i|^2 / N
    sources = torch.tensor([[1.0], [2.0], [3.0]])
    targets = torch.tensor([[4.0], [5.0], [6.0]])

    # W_2^2 = ((1-4)^2 + (2-5)^2 + (3-6)^2) / 3 = 9
    exact = 9.0

    criterion = SinkhornLoss(p=2, eps=1e-3, max_iter=5000)
    estimate = criterion(sources, targets).item()
    assert abs(estimate - exact) < 0.1
Loading