From e5f50f4b4ee37d5717fdafe05e5d803831d05ddf Mon Sep 17 00:00:00 2001 From: cyberguli Date: Wed, 17 Jun 2026 12:10:43 +0200 Subject: [PATCH 1/5] add sinkhorn loss --- docs/source/_rst/loss/sinkhorn_loss.rst | 10 ++ pina/_src/loss/sinkhorn_loss.py | 127 ++++++++++++++++++++++++ pina/loss/__init__.py | 2 + tests/test_loss/test_sinkhorn_loss.py | 83 ++++++++++++++++ 4 files changed, 222 insertions(+) create mode 100644 docs/source/_rst/loss/sinkhorn_loss.rst create mode 100644 pina/_src/loss/sinkhorn_loss.py create mode 100644 tests/test_loss/test_sinkhorn_loss.py diff --git a/docs/source/_rst/loss/sinkhorn_loss.rst b/docs/source/_rst/loss/sinkhorn_loss.rst new file mode 100644 index 000000000..d997c3ec3 --- /dev/null +++ b/docs/source/_rst/loss/sinkhorn_loss.rst @@ -0,0 +1,10 @@ +Lp Loss +=============== +.. currentmodule:: pina.loss.sinkhorn_loss + +.. automodule:: pina._src.loss.sinkhorn_loss + :no-members: + +.. autoclass:: pina._src.loss.sinkhorn_loss.SinkhornLoss + :members: + :show-inheritance: diff --git a/pina/_src/loss/sinkhorn_loss.py b/pina/_src/loss/sinkhorn_loss.py new file mode 100644 index 000000000..2eb226451 --- /dev/null +++ b/pina/_src/loss/sinkhorn_loss.py @@ -0,0 +1,127 @@ +"""Module for the SinkhornLoss class.""" + +import torch +from pina._src.loss.base_dual_loss import BaseDualLoss +from pina._src.core.utils import check_consistency, check_positive_integer + + +class SinkhornLoss(BaseDualLoss): + r""" + Implementation of the Sinkhorn Loss based on regularized optimal transport. + It measures the regularized Wasserstein distance between the empirical + distributions represented by ``input`` (with :math:`N` samples) and + ``target`` (with :math:`M` samples), each in :math:`\mathbb{R}^D`. + + The loss solves the entropy-regularized optimal transport problem: + + .. 
math:: + W_\varepsilon(\mu, \nu) = \min_{\pi \in \Pi(\mu, \nu)} + \langle C, \pi \rangle - \varepsilon H(\pi), + + where :math:`C_{ij} = \|x_i - y_j\|_2^p` is the cost matrix, + :math:`H(\pi) = -\sum_{ij} \pi_{ij} \log \pi_{ij}` is the entropy of + the transport plan, and :math:`\varepsilon > 0` is the regularization + strength. The dual objective recovered by the Sinkhorn iterations is: + + .. math:: + W_\varepsilon = \langle a, f^* \rangle + \langle b, g^* \rangle, + + where :math:`a` and :math:`b` are uniform probability weights over the + :math:`N` and :math:`M` samples respectively, and :math:`f^*, g^*` are + the optimal dual potentials computed via log-space Sinkhorn iterations. + + If ``reduction`` is set to ``"mean"`` or ``"sum"``, the scalar transport + cost is aggregated accordingly (the output is always a scalar, so both + reductions are equivalent): + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{``mean''} \\ + \operatorname{sum}(L), & \text{if reduction} = \text{``sum''} + \end{cases} + + .. note:: + Unlike pointwise losses, the Sinkhorn loss operates on entire empirical + distributions, so the output is always a scalar regardless of the + number of samples. The ``reduction`` parameter is retained for API + consistency. + + .. note:: + Smaller values of ``eps`` approximate the true Wasserstein distance + more closely but may require more iterations to converge. + + .. note:: + The algorithm is taken from "Sinkhorn AutoEncoders", arXiv:1810.01118. + """ + + def __init__(self, p=2, eps=0.1, max_iter=100, reduction="mean"): + """ + Initialization of the :class:`SinkhornLoss` class. + + :param int p: Exponent of the cost function :math:`\|x_i - y_j\|_2^p`. + Default is ``2``. + :param float eps: Entropy regularization strength + :math:`\varepsilon > 0`. Larger values yield smoother transport + plans. Default is ``0.1``. + :param int max_iter: Number of Sinkhorn iterations. Default is ``100``. 
+ :param str reduction: The reduction method to aggregate the scalar loss. + Available options include: ``"none"``, ``"mean"``, ``"sum"``. + Default is ``"mean"``. + :raises ValueError: If ``p`` is not a numeric value. + :raises ValueError: If ``eps`` is not a positive float. + :raises AssertionError: If ``max_iter`` is not a strictly positive int. + """ + super().__init__(reduction=reduction) + + check_consistency(p, (int, float)) + check_consistency(eps, float) + if eps <= 0: + raise ValueError( + f"eps must be a strictly positive float, got {eps}." + ) + check_positive_integer(max_iter, strict=True) + + self.p = p + self.eps = eps + self.max_iter = max_iter + + def forward(self, input, target): + """ + Forward method of the loss function. + + :param torch.Tensor input: Input tensor of shape :math:`(N, D)`. + :param torch.Tensor target: Target tensor of shape :math:`(M, D)`. + :return: Sinkhorn loss value. + :rtype: torch.Tensor + """ + n = input.shape[0] + m = target.shape[0] + + a = input.new_full((n,), 1.0 / n) + b = target.new_full((m,), 1.0 / m) + + # Cost matrix C[i,j] = ||x_i - y_j||_2^p, shape (N, M) + diff = input.unsqueeze(1) - target.unsqueeze(0) # (N, M, D) + C = torch.linalg.norm(diff, ord=2, dim=-1).pow(self.p) # (N, M) + + # Log-space Sinkhorn iterations for numerical stability + log_a = a.log() + log_b = b.log() + f = torch.zeros(n, dtype=input.dtype, device=input.device) + g = torch.zeros(m, dtype=target.dtype, device=target.device) + + for _ in range(self.max_iter): + # f_i = eps * (log a_i - logsumexp_j ((g_j - C_ij) / eps)) + f = self.eps * ( + log_a + - torch.logsumexp((g.unsqueeze(0) - C) / self.eps, dim=1) + ) + # g_j = eps * (log b_j - logsumexp_i ((f_i - C_ij) / eps)) + g = self.eps * ( + log_b + - torch.logsumexp((f.unsqueeze(1) - C) / self.eps, dim=0) + ) + + loss = (a * f).sum() + (b * g).sum() + return self._reduction(loss.unsqueeze(0)) diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py index 52ed278c7..7966d2019 100644 --- 
a/pina/loss/__init__.py +++ b/pina/loss/__init__.py @@ -5,12 +5,14 @@ "BaseDualLoss", "LpLoss", "PowerLoss", + "SinkhornLoss" ] from pina._src.loss.dual_loss_interface import DualLossInterface from pina._src.loss.base_dual_loss import BaseDualLoss from pina._src.loss.power_loss import PowerLoss from pina._src.loss.lp_loss import LpLoss +from pina._src.loss.sinkhorn_loss import SinkhornLoss # Back-compatibility with version 0.2, to be removed soon import warnings diff --git a/tests/test_loss/test_sinkhorn_loss.py b/tests/test_loss/test_sinkhorn_loss.py new file mode 100644 index 000000000..40e647596 --- /dev/null +++ b/tests/test_loss/test_sinkhorn_loss.py @@ -0,0 +1,83 @@ +import torch +import pytest + +from pina.loss import SinkhornLoss + +# Fixed random tensors for reproducibility +torch.manual_seed(0) +input_ = torch.rand(10, 2) +target_ = torch.rand(8, 2) + + +@pytest.mark.parametrize("p", [1, 2, 3]) +@pytest.mark.parametrize("eps", [0.01, 0.1, 1.0]) +@pytest.mark.parametrize("max_iter", [10, 100]) +@pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) +def test_constructor(p, eps, max_iter, reduction): + + SinkhornLoss(p=p, eps=eps, max_iter=max_iter, reduction=reduction) + + # Should fail if p is not numeric + with pytest.raises(ValueError): + SinkhornLoss(p="invalid", eps=eps, max_iter=max_iter, reduction=reduction) + + # Should fail if eps is not a float + with pytest.raises(ValueError): + SinkhornLoss(p=p, eps=1, max_iter=max_iter, reduction=reduction) + + # Should fail if eps is not positive + with pytest.raises(ValueError): + SinkhornLoss(p=p, eps=-0.1, max_iter=max_iter, reduction=reduction) + + # Should fail if max_iter is not a positive integer + with pytest.raises(AssertionError): + SinkhornLoss(p=p, eps=eps, max_iter=0, reduction=reduction) + + # Should fail if reduction is invalid + with pytest.raises(ValueError): + SinkhornLoss(p=p, eps=eps, max_iter=max_iter, reduction="invalid") + + +@pytest.mark.parametrize("reduction", ["mean", "sum", 
"none"]) +def test_forward_shape(reduction): + + loss_fn = SinkhornLoss(reduction=reduction) + value = loss_fn(input_, target_) + assert value.shape == torch.Size([1]) + + +def test_forward_finite(): + + # The (non-debiased) Sinkhorn dual can be negative due to the entropy + # regularization term, but it must always be finite. + loss_fn = SinkhornLoss() + value = loss_fn(input_, target_) + assert torch.isfinite(value).all() + + +def test_forward_same_distribution_smaller(): + + # Sinkhorn loss on identical data should be smaller than on different data + loss_same = SinkhornLoss(eps=1e-3, max_iter=500)(input_, input_) + loss_diff = SinkhornLoss(eps=1e-3, max_iter=500)(input_, target_) + assert loss_same.item() < loss_diff.item() + + +def test_forward_asymmetric_sizes(): + + # input and target may have different numbers of rows + x = torch.rand(5, 3) + y = torch.rand(8, 3) + value = SinkhornLoss()(x, y) + assert value.shape == torch.Size([1]) + assert torch.isfinite(value).all() + + +def test_forward_approaches_wasserstein(): + + # For 1-D sorted distributions, W_2^2 = sum |x_i - y_i|^2 / N + x = torch.tensor([[1.0], [2.0], [3.0]]) + y = torch.tensor([[4.0], [5.0], [6.0]]) + # W_2^2 = ((1-4)^2 + (2-5)^2 + (3-6)^2) / 3 = 9 + value = SinkhornLoss(p=2, eps=1e-3, max_iter=5000)(x, y) + assert abs(value.item() - 9.0) < 0.1 From 9bcaef412cdf1a1e687cdb82dceb1a2205ddae06 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Wed, 17 Jun 2026 12:14:45 +0200 Subject: [PATCH 2/5] black formatter --- pina/_src/loss/sinkhorn_loss.py | 6 ++---- pina/loss/__init__.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pina/_src/loss/sinkhorn_loss.py b/pina/_src/loss/sinkhorn_loss.py index 2eb226451..f81d022ac 100644 --- a/pina/_src/loss/sinkhorn_loss.py +++ b/pina/_src/loss/sinkhorn_loss.py @@ -114,13 +114,11 @@ def forward(self, input, target): for _ in range(self.max_iter): # f_i = eps * (log a_i - logsumexp_j ((g_j - C_ij) / eps)) f = self.eps * ( - log_a - - 
torch.logsumexp((g.unsqueeze(0) - C) / self.eps, dim=1) + log_a - torch.logsumexp((g.unsqueeze(0) - C) / self.eps, dim=1) ) # g_j = eps * (log b_j - logsumexp_i ((f_i - C_ij) / eps)) g = self.eps * ( - log_b - - torch.logsumexp((f.unsqueeze(1) - C) / self.eps, dim=0) + log_b - torch.logsumexp((f.unsqueeze(1) - C) / self.eps, dim=0) ) loss = (a * f).sum() + (b * g).sum() diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py index 7966d2019..280cbf76a 100644 --- a/pina/loss/__init__.py +++ b/pina/loss/__init__.py @@ -5,7 +5,7 @@ "BaseDualLoss", "LpLoss", "PowerLoss", - "SinkhornLoss" + "SinkhornLoss", ] from pina._src.loss.dual_loss_interface import DualLossInterface From 2ad35c1c12bf06483c453e7f9912147666f27b1c Mon Sep 17 00:00:00 2001 From: cyberguli Date: Thu, 18 Jun 2026 16:48:29 +0200 Subject: [PATCH 3/5] fixing docs, merging tests and removing reduction --- docs/source/_rst/loss/sinkhorn_loss.rst | 2 +- pina/_src/loss/sinkhorn_loss.py | 41 +++++-------- tests/test_loss/test_sinkhorn_loss.py | 82 +++++++++++-------------- 3 files changed, 51 insertions(+), 74 deletions(-) diff --git a/docs/source/_rst/loss/sinkhorn_loss.rst b/docs/source/_rst/loss/sinkhorn_loss.rst index d997c3ec3..6305766a1 100644 --- a/docs/source/_rst/loss/sinkhorn_loss.rst +++ b/docs/source/_rst/loss/sinkhorn_loss.rst @@ -1,4 +1,4 @@ -Lp Loss +Sinkhorn Loss =============== .. currentmodule:: pina.loss.sinkhorn_loss diff --git a/pina/_src/loss/sinkhorn_loss.py b/pina/_src/loss/sinkhorn_loss.py index f81d022ac..931355a4e 100644 --- a/pina/_src/loss/sinkhorn_loss.py +++ b/pina/_src/loss/sinkhorn_loss.py @@ -30,23 +30,12 @@ class SinkhornLoss(BaseDualLoss): :math:`N` and :math:`M` samples respectively, and :math:`f^*, g^*` are the optimal dual potentials computed via log-space Sinkhorn iterations. - If ``reduction`` is set to ``"mean"`` or ``"sum"``, the scalar transport - cost is aggregated accordingly (the output is always a scalar, so both - reductions are equivalent): - - .. 
math:: - \ell(x, y) = - \begin{cases} - \operatorname{mean}(L), & \text{if reduction} = \text{``mean''} \\ - \operatorname{sum}(L), & \text{if reduction} = \text{``sum''} - \end{cases} .. note:: Unlike pointwise losses, the Sinkhorn loss operates on entire empirical distributions, so the output is always a scalar regardless of the - number of samples. The ``reduction`` parameter is retained for API - consistency. - + number of samples. + .. note:: Smaller values of ``eps`` approximate the true Wasserstein distance more closely but may require more iterations to converge. @@ -55,27 +44,23 @@ class SinkhornLoss(BaseDualLoss): The algorithm is taken from "Sinkhorn AutoEncoders", arXiv:1810.01118. """ - def __init__(self, p=2, eps=0.1, max_iter=100, reduction="mean"): + def __init__(self, p=2, eps=0.1, max_iter=100): """ Initialization of the :class:`SinkhornLoss` class. - :param int p: Exponent of the cost function :math:`\|x_i - y_j\|_2^p`. - Default is ``2``. - :param float eps: Entropy regularization strength - :math:`\varepsilon > 0`. Larger values yield smoother transport - plans. Default is ``0.1``. + :param int p: Exponent of the cost function. Default is ``2``. + :param eps: Entropy regularization strength + :math:`\\varepsilon > 0`. Default is ``0.1``. + :type eps: int | float :param int max_iter: Number of Sinkhorn iterations. Default is ``100``. - :param str reduction: The reduction method to aggregate the scalar loss. - Available options include: ``"none"``, ``"mean"``, ``"sum"``. - Default is ``"mean"``. :raises ValueError: If ``p`` is not a numeric value. - :raises ValueError: If ``eps`` is not a positive float. + :raises ValueError: If ``eps`` is not a positive number. :raises AssertionError: If ``max_iter`` is not a strictly positive int. 
""" - super().__init__(reduction=reduction) + super().__init__(reduction="mean") check_consistency(p, (int, float)) - check_consistency(eps, float) + check_consistency(eps, (int, float)) if eps <= 0: raise ValueError( f"eps must be a strictly positive float, got {eps}." @@ -100,6 +85,9 @@ def forward(self, input, target): a = input.new_full((n,), 1.0 / n) b = target.new_full((m,), 1.0 / m) + f = torch.zeros(n, dtype=input.dtype, device=input.device) + g = torch.zeros(m, dtype=target.dtype, device=target.device) + # Cost matrix C[i,j] = ||x_i - y_j||_2^p, shape (N, M) diff = input.unsqueeze(1) - target.unsqueeze(0) # (N, M, D) @@ -108,8 +96,7 @@ def forward(self, input, target): # Log-space Sinkhorn iterations for numerical stability log_a = a.log() log_b = b.log() - f = torch.zeros(n, dtype=input.dtype, device=input.device) - g = torch.zeros(m, dtype=target.dtype, device=target.device) + for _ in range(self.max_iter): # f_i = eps * (log a_i - logsumexp_j ((g_j - C_ij) / eps)) diff --git a/tests/test_loss/test_sinkhorn_loss.py b/tests/test_loss/test_sinkhorn_loss.py index 40e647596..a28dcea5f 100644 --- a/tests/test_loss/test_sinkhorn_loss.py +++ b/tests/test_loss/test_sinkhorn_loss.py @@ -9,70 +9,60 @@ target_ = torch.rand(8, 2) -@pytest.mark.parametrize("p", [1, 2, 3]) -@pytest.mark.parametrize("eps", [0.01, 0.1, 1.0]) +@pytest.mark.parametrize("p", [1, 2]) +@pytest.mark.parametrize("eps", [0.01, 0.1]) @pytest.mark.parametrize("max_iter", [10, 100]) -@pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) -def test_constructor(p, eps, max_iter, reduction): +def test_constructor(p, eps, max_iter): - SinkhornLoss(p=p, eps=eps, max_iter=max_iter, reduction=reduction) + SinkhornLoss(p=p, eps=eps, max_iter=max_iter) # Should fail if p is not numeric with pytest.raises(ValueError): - SinkhornLoss(p="invalid", eps=eps, max_iter=max_iter, reduction=reduction) + SinkhornLoss( + p="invalid", eps=eps, max_iter=max_iter + ) - # Should fail if eps is not a float + # 
Should fail if eps is not numeric with pytest.raises(ValueError): - SinkhornLoss(p=p, eps=1, max_iter=max_iter, reduction=reduction) + SinkhornLoss(p=p, eps="bad", max_iter=max_iter) + # Should fail if eps is not positive with pytest.raises(ValueError): - SinkhornLoss(p=p, eps=-0.1, max_iter=max_iter, reduction=reduction) + SinkhornLoss(p=p, eps=-0.1, max_iter=max_iter) # Should fail if max_iter is not a positive integer with pytest.raises(AssertionError): - SinkhornLoss(p=p, eps=eps, max_iter=0, reduction=reduction) - - # Should fail if reduction is invalid - with pytest.raises(ValueError): - SinkhornLoss(p=p, eps=eps, max_iter=max_iter, reduction="invalid") - + SinkhornLoss(p=p, eps=eps, max_iter=0) -@pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) -def test_forward_shape(reduction): - loss_fn = SinkhornLoss(reduction=reduction) - value = loss_fn(input_, target_) - assert value.shape == torch.Size([1]) - - -def test_forward_finite(): - - # The (non-debiased) Sinkhorn dual can be negative due to the entropy - # regularization term, but it must always be finite. 
- loss_fn = SinkhornLoss() - value = loss_fn(input_, target_) - assert torch.isfinite(value).all() - -def test_forward_same_distribution_smaller(): - - # Sinkhorn loss on identical data should be smaller than on different data - loss_same = SinkhornLoss(eps=1e-3, max_iter=500)(input_, input_) - loss_diff = SinkhornLoss(eps=1e-3, max_iter=500)(input_, target_) - assert loss_same.item() < loss_diff.item() - - -def test_forward_asymmetric_sizes(): - - # input and target may have different numbers of rows - x = torch.rand(5, 3) - y = torch.rand(8, 3) - value = SinkhornLoss()(x, y) +@pytest.mark.parametrize("p", [1, 2, 3]) +@pytest.mark.parametrize("eps", [0.01, 0.1, 1.0, 1]) +@pytest.mark.parametrize("max_iter", [10, 100]) +@pytest.mark.parametrize("input, target", [ + (torch.rand(10, 2), torch.rand(8, 2)), # different N, M; 2D features + (torch.rand(5, 3), torch.rand(5, 3)), # same N=M; 3D features + (torch.rand(1, 4), torch.rand(7, 4)), # single input sample + (torch.rand(6, 4), torch.rand(1, 4)), # single target sample + (torch.rand(3, 1), torch.rand(4, 1)), # 1D feature space + (input_, input_), # identical distributions + (input_, target_), # different distributions +]) +def test_forward(p, eps, max_iter, input, target): + # Output must always be a scalar and finite. The (non-debiased) Sinkhorn + # dual can be negative due to the entropy regularization term; this is + # expected behavior. eps can be an integer as well as a float. + loss_fn = SinkhornLoss(p=p, eps=eps, max_iter=max_iter) + value = loss_fn(input, target) assert value.shape == torch.Size([1]) assert torch.isfinite(value).all() - - + # Sinkhorn loss on identical data should be smaller than on different data. 
+ if input is input_ and target is target_: + loss_same = SinkhornLoss(p=p, eps=eps, max_iter=500)(input_, input_) + assert loss_same.item() < value.item() + + def test_forward_approaches_wasserstein(): # For 1-D sorted distributions, W_2^2 = sum |x_i - y_i|^2 / N From da65e7455b2629b408cc2e42d8af7276c3238568 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Thu, 18 Jun 2026 16:51:24 +0200 Subject: [PATCH 4/5] updated _code.rst --- docs/source/_rst/_code.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 0c289183e..ecd50ec7d 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -330,6 +330,8 @@ Losses BaseDualLoss LpLoss PowerLoss + SinkhornLoss + Weighting Schemas -------------------- @@ -343,4 +345,4 @@ Weighting Schemas Neural-Tangent-Kernel Weighting No Weighting Scalar Weighting - Self-Adaptive Weighting \ No newline at end of file + Self-Adaptive Weighting From f444a1010601cd9a1f43307c5f1f2ff9f44f905c Mon Sep 17 00:00:00 2001 From: cyberguli Date: Thu, 18 Jun 2026 16:55:53 +0200 Subject: [PATCH 5/5] updated paper --- pina/_src/loss/sinkhorn_loss.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pina/_src/loss/sinkhorn_loss.py b/pina/_src/loss/sinkhorn_loss.py index 931355a4e..a7da86217 100644 --- a/pina/_src/loss/sinkhorn_loss.py +++ b/pina/_src/loss/sinkhorn_loss.py @@ -40,8 +40,12 @@ class SinkhornLoss(BaseDualLoss): Smaller values of ``eps`` approximate the true Wasserstein distance more closely but may require more iterations to converge. - .. note:: - The algorithm is taken from "Sinkhorn AutoEncoders", arXiv:1810.01118. + .. seealso:: + + **Original reference:** Patrini, G., van den Berg, R., Forre, P., + Carioni, M., Bhargav, S., et al. *Sinkhorn AutoEncoders*. In Proceedings + of the Thirty-Fifth Conference on Uncertainty in Artificial Intelligence + (UAI), 2019. `arXiv:1810.01118 <https://arxiv.org/abs/1810.01118>`_. """ def __init__(self, p=2, eps=0.1, max_iter=100):