sup3r.utilities.loss_metrics.SlicedWassersteinLoss
- class SlicedWassersteinLoss(n_projections=1024)
Bases: Sup3rLoss
Loss class for the sliced Wasserstein distance.
- Parameters:
n_projections (int) – number of random 1D projections to use
Note
Experimentally, the SW metric is stable when n_projections is at least
30% of the number of projection dimensions, which here is H·W·T
(spatial_1 × spatial_2 × temporal). Because this can be computationally
expensive for large spatial or temporal sizes, the default is 1024.
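As a worked example of this rule of thumb, the sketch below computes the smallest n_projections that reaches 30% of H·W·T for a given patch shape. The helper name recommended_n_projections and the 20 × 20 × 24 patch size are hypothetical illustrations, not part of sup3r.

```python
def recommended_n_projections(spatial_1, spatial_2, temporal, percent=30):
    """Smallest n_projections that is at least `percent`% of H*W*T.

    Hypothetical helper illustrating the rule of thumb in the Note;
    it is not part of the sup3r API.
    """
    n_dims = spatial_1 * spatial_2 * temporal  # projection dimensions H*W*T
    # Integer ceiling division avoids floating-point rounding at the boundary.
    return -(-percent * n_dims // 100)


# Hypothetical 20 x 20 spatial patch with 24 time steps: H*W*T = 9600.
print(recommended_n_projections(20, 20, 24))  # 2880, well above the default 1024
# Largest H*W*T that the default of 1024 still covers under the 30% rule.
print(max(n for n in range(1, 5000) if recommended_n_projections(n, 1, 1) <= 1024))  # 3413
```

Integer arithmetic keeps the boundary exact: with the default of 1024, the 30% guideline holds only while H·W·T stays at or below 3413.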
Methods
call(x_true, x_gen) – Sliced Wasserstein distance based on random 1D projections
from_config(config)
get_config()
Attributes
dtype
- call(x_true, x_gen)
Sliced Wasserstein distance based on random 1D projections
- Parameters:
x_true (tf.Tensor) – high-resolution ground truth data (n_observations, spatial_1, spatial_2, temporal, features)
x_gen (tf.Tensor) – synthetic generator output (n_observations, spatial_1, spatial_2, temporal, features)
- Returns:
tf.Tensor – 0D (scalar) tensor loss value
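The computation can be illustrated with a NumPy sketch of the standard Monte Carlo sliced Wasserstein-2 estimator: draw random unit directions, project both sample sets onto each one, sort the 1D projections (sorting pairs quantiles, which solves optimal transport in 1D), and average the squared differences. This is not sup3r's TensorFlow implementation. The NumPy port, the function names sliced_wasserstein and to_points, the seed argument, returning the square root, and flattening each (observation, feature) pair into a point in R^(H·W·T) are all assumptions made for illustration.

```python
import numpy as np


def sliced_wasserstein(x_true, x_gen, n_projections=1024, seed=None):
    """Monte Carlo estimate of the sliced Wasserstein-2 distance.

    x_true, x_gen: arrays of shape (n_points, dim) with the same number of
    points. Returns a Python float. Illustrative sketch, not sup3r's code.
    """
    rng = np.random.default_rng(seed)
    dim = x_true.shape[1]
    # Random directions drawn uniformly on the unit sphere in R^dim.
    theta = rng.standard_normal((dim, n_projections))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)
    # Project every point onto every direction: (n_points, n_projections),
    # then sort each column so matching rows are matching quantiles.
    proj_true = np.sort(x_true @ theta, axis=0)
    proj_gen = np.sort(x_gen @ theta, axis=0)
    # Mean over points and directions estimates SW_2 squared; take the root.
    return float(np.sqrt(np.mean((proj_true - proj_gen) ** 2)))


def to_points(x):
    """Hypothetical layout: (n_obs, s1, s2, t, features) -> points in R^(s1*s2*t).

    Each (observation, feature) pair becomes one point, so the projection
    dimension is H*W*T as described in the Note (an assumption).
    """
    n_obs, s1, s2, t, n_feat = x.shape
    return np.transpose(x, (0, 4, 1, 2, 3)).reshape(n_obs * n_feat, s1 * s2 * t)


rng = np.random.default_rng(0)
x_true = rng.standard_normal((8, 4, 4, 3, 2))  # toy high-resolution batch
x_gen = x_true + 0.5  # toy "generator output" shifted by a constant
loss = sliced_wasserstein(to_points(x_true), to_points(x_gen), seed=1)
print(loss)
```

Sorting makes each projection cost O(n log n) rather than solving a full transport problem, which is why the number of projections, not the number of points, dominates the cost trade-off discussed in the Note.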
- __call__(y_true, y_pred, sample_weight=None)
Call self as a function.