sup3r.utilities.loss_metrics.SlicedWassersteinLoss


class SlicedWassersteinLoss(n_projections=1024)

Bases: Sup3rLoss

Loss class for the sliced Wasserstein (SW) distance.

Parameters:
  • n_projections (int) – Number of random 1D projections to use.

Note

Experimentally, we get stability in the SW metric when n_projections is at least 30% of the number of projection dimensions, which for us is H × W × T (spatial_1 × spatial_2 × temporal). This can be computationally expensive for large spatial/temporal sizes, so we default to 1024.
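The 30% heuristic in the note is easy to check with a little arithmetic. The sketch below assumes the 5D layout documented for `call` and reads HWT as spatial_1 × spatial_2 × temporal; `min_stable_projections` and its `fraction_percent` argument are hypothetical helpers for illustration, not part of sup3r.

```python
def min_stable_projections(shape, fraction_percent=30):
    """Rough lower bound on n_projections from the 30% rule of thumb.

    Hypothetical helper: assumes shape is
    (n_observations, spatial_1, spatial_2, temporal, features).
    """
    _, h, w, t, _ = shape
    n_dims = h * w * t  # projection dimension, read as H*W*T
    # Integer ceiling of n_dims * fraction_percent / 100 (avoids float rounding).
    return -(-n_dims * fraction_percent // 100)


# A small 10 x 10 x 6 patch needs only ~180 projections, well under 1024,
# while a 96 x 96 x 24 patch would need ~66k, which is why 1024 is a compromise.
print(min_stable_projections((16, 10, 10, 6, 2)))
print(min_stable_projections((16, 96, 96, 24, 2)))
```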

Methods

call(x_true, x_gen)

Sliced Wasserstein distance based on random 1D projections

from_config(config)

get_config()

Attributes

dtype

call(x_true, x_gen)

Sliced Wasserstein distance based on random 1D projections

Parameters:
  • x_true (tf.Tensor) – High-resolution ground truth data with shape (n_observations, spatial_1, spatial_2, temporal, features).

  • x_gen (tf.Tensor) – Synthetic generator output with shape (n_observations, spatial_1, spatial_2, temporal, features).

Returns:

tf.Tensor – 0D (scalar) tensor containing the loss value.
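For intuition, here is a minimal NumPy sketch of a sliced Wasserstein distance built from random 1D projections. It is not the sup3r implementation, which operates on TensorFlow tensors. The function name `sliced_wasserstein`, the `seed` argument, the choice to treat every (observation, feature) pair as one sample of dimension H*W*T, and the use of the squared 2-Wasserstein distance per projection are all assumptions made for illustration.

```python
import numpy as np


def sliced_wasserstein(x_true, x_gen, n_projections=1024, seed=None):
    """Monte Carlo estimate of the squared sliced 2-Wasserstein distance.

    Hypothetical sketch: inputs are 5D arrays shaped
    (n_observations, spatial_1, spatial_2, temporal, features), and each
    (observation, feature) pair is assumed to be one sample in R^(H*W*T).
    """
    x_true = np.asarray(x_true, dtype=np.float64)
    x_gen = np.asarray(x_gen, dtype=np.float64)
    if x_true.ndim != 5 or x_true.shape != x_gen.shape:
        raise ValueError("expected two 5D arrays with the same shape")

    def to_samples(x):
        # (n_obs, H, W, T, F) -> (n_obs * F, H * W * T)
        n_obs, h, w, t, f = x.shape
        return np.moveaxis(x, -1, 1).reshape(n_obs * f, h * w * t)

    a, b = to_samples(x_true), to_samples(x_gen)
    dim = a.shape[1]

    # Normalized Gaussian vectors are uniformly distributed on the unit sphere.
    rng = np.random.default_rng(seed)
    theta = rng.normal(size=(dim, n_projections))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)

    # Project every sample onto every direction, then sort each column:
    # in 1D the optimal transport plan pairs values in sorted order.
    proj_a = np.sort(a @ theta, axis=0)
    proj_b = np.sort(b @ theta, axis=0)

    # Mean over samples gives W2^2 per projection; mean over projections
    # averages those into the sliced estimate.
    return float(np.mean((proj_a - proj_b) ** 2))


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8, 8, 6, 2))
y = rng.normal(loc=1.0, size=(4, 8, 8, 6, 2))
print(sliced_wasserstein(x, y, n_projections=64, seed=1))
```

Sorting is what makes each 1D problem cheap: no optimization is needed, only an O(n log n) sort per projection. Increasing `n_projections` reduces the Monte Carlo noise of the estimate at a roughly linear cost in compute and memory, which is the trade-off the note on `n_projections` describes.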

__call__(y_true, y_pred, sample_weight=None)

Call self as a function.