sup3r.utilities.loss_metrics.tf_derivative



tf_derivative(x, axis=1)

Custom derivative function for compatibility with TensorFlow.

Note

Matches np.gradient by using the central difference approximation.

Parameters:
  • x (tf.Tensor) – Input tensor with shape (n_observations, spatial_1, spatial_2, temporal)

  • axis (int) – Axis along which to take the derivative. Defaults to 1 (spatial_1).
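The note says this function matches np.gradient by using the central difference approximation. The NumPy sketch below illustrates what that means for unit grid spacing on a tensor shaped like x above. `central_difference` is a hypothetical helper written for illustration, not part of sup3r. It assumes np.gradient's default boundary handling (first-order one-sided differences at the two edge points); this page does not say how tf_derivative itself treats the boundaries.

```python
import numpy as np


def central_difference(x, axis=1):
    """Hypothetical helper (not sup3r API): unit-spacing derivative along ``axis``.

    Interior points use the central difference (f[i+1] - f[i-1]) / 2.
    The two boundary points use first-order one-sided differences,
    which is what np.gradient does by default (edge_order=1, spacing 1).
    """
    x = np.asarray(x, dtype=float)
    n = x.shape[axis]
    if n < 2:
        raise ValueError("need at least two points along the derivative axis")

    def take(start, stop):
        # Select indices [start, stop) along the requested axis only.
        return np.take(x, np.arange(start, stop), axis=axis)

    interior = (take(2, n) - take(0, n - 2)) / 2.0   # central difference
    first = take(1, 2) - take(0, 1)                   # forward difference at index 0
    last = take(n - 1, n) - take(n - 2, n - 1)        # backward difference at n - 1
    return np.concatenate([first, interior, last], axis=axis)


# Usage: compare against np.gradient on a (n_observations, spatial_1,
# spatial_2, temporal) array for each non-batch axis.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 4, 3))
for ax in (1, 2, 3):
    assert np.allclose(central_difference(x, axis=ax), np.gradient(x, axis=ax))
```

Slicing and concatenating along one axis, as above, is one way to express the same stencil with tensor operations; whether tf_derivative is built this way is an assumption, not something stated here.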