sup3r.preprocessing.samplers.dual.DualSampler
- class DualSampler(data: Sup3rDataset, sample_shape: tuple | None = None, batch_size: int = 16, s_enhance: int = 1, t_enhance: int = 1, feature_sets: dict | None = None, proxy_obs_kwargs: dict | None = None, mode: str = 'lazy')
Bases: Sampler
Sampler for sampling from paired (or dual) datasets. Pairs consist of low- and high-resolution data, which are contained in a Sup3rDataset. This can also include extra observation data on the same grid as the high-resolution data, with NaNs at points where observation data doesn’t exist. This observation data is used in an additional content loss term.
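Because unobserved points are NaN, a loss term over observation data can simply skip them. A minimal sketch of such a masked mean squared error, assuming flat lists stand in for gridded arrays; `masked_mse` is a hypothetical helper and not necessarily the loss term sup3r actually uses:

```python
import math


def masked_mse(pred, obs):
    """Mean squared error over points where `obs` is not NaN (hypothetical helper)."""
    sq_errors = [
        (p - o) ** 2
        for p, o in zip(pred, obs)
        if not math.isnan(o)  # skip grid points without an observation
    ]
    if not sq_errors:
        return 0.0  # no observations in this sample, so no contribution
    return sum(sq_errors) / len(sq_errors)


nan = float("nan")
print(masked_mse([1.0, 2.0, 3.0, 4.0], [1.5, nan, 2.0, nan]))  # (0.25 + 1.0) / 2 = 0.625
```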
- Parameters:
data (Sup3rDataset) – A Sup3rDataset instance with low-res and high-res data members.
sample_shape (tuple) – Size of arrays to sample from the high-res data. The sample shape for the low-res sampler will be determined from the enhancement factors.
s_enhance (int) – Spatial enhancement factor.
t_enhance (int) – Temporal enhancement factor.
feature_sets (Optional[dict]) – Optional dictionary describing how the full set of features is split between lr_features, hr_exo_features, and hr_out_features.
- lr_features : list | tuple
List of feature names or patterns to use as low-resolution model inputs. If no entry is provided then all available features from the data will be used.
- hr_out_features : list | tuple
List of feature names or patterns that should be output by the generative model and available as ground truth targets. If no entry is provided then all features in the high-res data will be used.
- hr_exo_features : list | tuple
List of feature names or patterns that should be available as high-resolution model inputs (like topography or observations) or to bespoke loss functions. Features used as inputs are injected into the model mid-network to condition the output on high-resolution information. The model configuration should have the appropriate layers to use these features, e.g. Sup3rConcat for topography injection, or Sup3rObsModel or Sup3rCrossAttention for obs injection. If no entry is provided then hr_exo_features will be empty.
Note: to include sparse features as inputs or targets, the feature names must have an “_obs” suffix.
proxy_obs_kwargs (dict | None) – Optional dictionary of keyword arguments to pass to the proxy observation generator. This is only used when training with proxy observations. Keys can include onshore_obs_frac, offshore_obs_frac, and perturbation_scale.
- perturbation_scale : float
Scale of the perturbation to add to the proxy observations. This specifies the multiplier of the noise sampled from (-standard deviation, standard deviation). The standard deviation is calculated per feature over each batch.
- onshore_obs_frac : float | dict
Fraction of onshore observations to include in each batch when using proxy observations. This can be a single float or a dictionary with keys ‘spatial’ and ‘temporal’ to specify the fraction for each domain. If a dictionary is provided, the actual fraction for each batch will be sampled uniformly between the specified spatial and temporal fractions.
- offshore_obs_frac : float | dict
Fraction of offshore observations to include in each batch when using proxy observations. This can be a single float or a dictionary with keys ‘spatial’ and ‘temporal’ to specify the fraction for each domain. If a dictionary is provided, the actual fraction for each batch will be sampled uniformly between the specified spatial and temporal fractions.
mode (str) – Mode for sampling data. Options are ‘lazy’ or ‘eager’. ‘eager’ mode pre-loads all data into memory as numpy arrays for faster access. ‘lazy’ mode samples directly from the underlying data object, which could be backed by dask arrays or on-disk netCDF files.
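The feature_sets defaults described above can be illustrated with a small resolver. This is a sketch under stated assumptions: `resolve_feature_sets` and `expand` are hypothetical helpers, and glob-style matching with `fnmatch.fnmatchcase` is assumed as a stand-in for however sup3r actually matches feature patterns.

```python
from fnmatch import fnmatchcase


def expand(patterns, available):
    """Features in `available` matching any name or glob pattern, in data order."""
    return [f for f in available if any(fnmatchcase(f, p) for p in patterns)]


def resolve_feature_sets(lr_available, hr_available, feature_sets=None):
    """Resolve feature lists using the documented defaults (hypothetical helper)."""
    feature_sets = feature_sets or {}
    lr = feature_sets.get("lr_features")
    hr_out = feature_sets.get("hr_out_features")
    hr_exo = feature_sets.get("hr_exo_features")
    return {
        # no entry -> all available low-res features
        "lr_features": expand(lr, lr_available) if lr else list(lr_available),
        # no entry -> all features in the high-res data
        "hr_out_features": expand(hr_out, hr_available) if hr_out else list(hr_available),
        # no entry -> empty
        "hr_exo_features": expand(hr_exo, hr_available) if hr_exo else [],
    }


lr_avail = ["u_10m", "v_10m", "temperature_2m", "topography"]
hr_avail = ["u_10m", "v_10m", "temperature_2m", "topography", "u_10m_obs"]
resolved = resolve_feature_sets(
    lr_avail,
    hr_avail,
    {"hr_out_features": ["*_10m"], "hr_exo_features": ["topography", "*_obs"]},
)
print(resolved["hr_out_features"])  # ['u_10m', 'v_10m']
print(resolved["hr_exo_features"])  # ['topography', 'u_10m_obs']
```

The feature names here are illustrative; note the exogenous features sit at the end of the high-res list, as hr_exo_features requires.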
Methods
check_feature_consistency() – Make sure features are consistent with the data and with each other.
check_proxy_obs_consistency() – Check that the obs features are configured correctly for proxy observations.
check_shape_consistency() – Make sure container shapes are compatible with enhancement factors.
get_sample_index([n_obs]) – Get a paired sample index, consisting of the index for the low-res sample and the index for the high-res sample with the same spatiotemporal extent.
post_init_log([args_dict]) – Log additional arguments after initialization.
preflight() – Perform shape and feature checks.
wrap(data) – Return a Sup3rDataset object or tuple of such.
Attributes
timer
data – Return underlying data.
hr_exo_features – Get a list of exogenous high-resolution features that are only used for training e.g., mid-network high-res topo injection.
hr_features – List of feature names or patterns that the model is shown at high-resolution.
hr_features_ind – Get the high-resolution feature channel indices that should be included for loss calculations.
hr_out_features – List of feature names or patterns that should be output by the generative model.
hr_sample_shape – Shape of the data sample to select when __next__() is called.
hr_source_features – Features available natively at high-resolution.
lr_features – List of feature names or patterns to use as low-resolution model inputs.
lr_features_ind – Get the low-resolution feature channel indices that should be included for training.
obs_features – List of feature names or patterns that should be treated as observations.
obs_features_ind – Get the source feature indices in features for each obs feature.
offshore_obs_frac – Fraction of offshore observations to include in each batch when using proxy observations.
onshore_obs_frac – Fraction of onshore observations to include in each batch when using proxy observations.
perturbation_scale – Scale of the perturbation to add to the proxy observations when using proxy observations.
sample_shape – Shape of the data sample to select when __next__() is called.
shape – Get shape of underlying data.
use_proxy_obs – Whether to use proxy observations.
- property data
Return underlying data.
- property hr_source_features
Features available natively at high-resolution.
- check_feature_consistency()
Make sure features are consistent with the data and with each other.
- check_shape_consistency()
Make sure container shapes are compatible with enhancement factors.
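As one possible reading of "compatible", the sketch below assumes the high-res (rows, cols, time) extent must equal the low-res extent scaled by s_enhance spatially and t_enhance temporally. `check_shapes` is a hypothetical helper, not the method's actual implementation:

```python
def check_shapes(lr_shape, hr_shape, s_enhance, t_enhance):
    """Raise ValueError unless hr (rows, cols, time) == lr scaled by the factors.

    Hypothetical helper; any trailing feature dimension is ignored.
    """
    lr = tuple(lr_shape[:3])
    hr = tuple(hr_shape[:3])
    expected = (lr[0] * s_enhance, lr[1] * s_enhance, lr[2] * t_enhance)
    if hr != expected:
        raise ValueError(
            f"high-res shape {hr} does not match low-res shape {lr} scaled by "
            f"s_enhance={s_enhance}, t_enhance={t_enhance} (expected {expected})"
        )


check_shapes((10, 10, 24), (40, 40, 96, 3), s_enhance=4, t_enhance=4)  # passes silently
```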
- get_sample_index(n_obs=None)
Get a paired sample index, consisting of the index for the low-res sample and the index for the high-res sample with the same spatiotemporal extent. Optionally includes an extra high-res index if the sample data includes observation data.
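Paired indexing can be sketched as picking a random low-res window and scaling its start and size by the enhancement factors to get the matching high-res window. Assumptions: shapes are (rows, cols, time), hr_sample_shape is divisible by the factors, indices are tuples of slices with a trailing slice(None) over features, and the optional n_obs index is not modeled. `paired_sample_index` is hypothetical:

```python
import random


def paired_sample_index(lr_data_shape, hr_sample_shape, s_enhance, t_enhance, rng=random):
    """Return (lr_index, hr_index) windows covering the same spatiotemporal extent."""
    factors = (s_enhance, s_enhance, t_enhance)
    # low-res sample shape follows from the high-res sample shape and the factors
    lr_sample_shape = tuple(n // f for n, f in zip(hr_sample_shape, factors))
    # random low-res start in each dimension such that the window fits in the data
    starts = [rng.randint(0, n - m) for n, m in zip(lr_data_shape, lr_sample_shape)]
    lr_index = tuple(slice(s, s + m) for s, m in zip(starts, lr_sample_shape))
    # scale start and stop by the enhancement factors for the high-res window
    hr_index = tuple(
        slice(s * f, (s + m) * f) for s, m, f in zip(starts, lr_sample_shape, factors)
    )
    # trailing slice(None) selects all feature channels
    return lr_index + (slice(None),), hr_index + (slice(None),)


lr_idx, hr_idx = paired_sample_index((20, 20, 48), (16, 16, 24), 4, 6, random.Random(0))
```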
- check_proxy_obs_consistency()
Check that the obs features are configured correctly for proxy observations.
- property hr_exo_features
Get a list of exogenous high-resolution features that are only used for training e.g., mid-network high-res topo injection. These must come at the end of the high-res feature set. These can also be input to the model as low-res features.
- property hr_features
List of feature names or patterns that the model is shown at high-resolution. This does not include features that are only shown to the model after coarsening. Thus, it includes hr_out_features and hr_exo_features.
- property hr_features_ind
Get the high-resolution feature channel indices that should be included for loss calculations. This includes hr_out_features and hr_exo_features. Any high-resolution features that are only included in the data handler to be coarsened for the low-res input are removed.
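The channel selection described above can be sketched with hypothetical feature names, where `pressure` stands in for a feature that is present only to be coarsened into a low-res input. `hr_channel_indices` is a hypothetical helper:

```python
def hr_channel_indices(hr_data_features, hr_out_features, hr_exo_features):
    """Indices of hr_out and hr_exo features within the high-res data channels."""
    keep = set(hr_out_features) | set(hr_exo_features)
    # features not in `keep` (e.g. only present to be coarsened) are dropped
    return [i for i, f in enumerate(hr_data_features) if f in keep]


hr_data_features = ["u_10m", "v_10m", "pressure", "topography"]
print(hr_channel_indices(hr_data_features, ["u_10m", "v_10m"], ["topography"]))  # [0, 1, 3]
```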
- property hr_out_features
List of feature names or patterns that should be output by the generative model. If no entry is provided then all features in hr_features will be used.
- property hr_sample_shape: tuple
Shape of the data sample to select when __next__() is called. Same as sample_shape.
- property lr_features
List of feature names or patterns to use as low-resolution model inputs. If no entry is provided then all available features from the data will be used.
- property lr_features_ind
Get the low-resolution feature channel indices that should be included for training. This includes lr_features.
- property obs_features
List of feature names or patterns that should be treated as observations. These features will be included in the high-res data but not the low-res data and won’t necessarily be expected to be output by the generative model. They differ from other hr_exo_features in that they are intended to be used as observation features, with NaN values where observations are not available.
- property obs_features_ind
Get the source feature indices in features for each obs feature. Each obs feature named <feature>_obs maps to the corresponding <feature> in features.
- Returns:
list[int] – Indices into features for each obs feature source.
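The <feature>_obs naming convention makes the source lookup a suffix strip followed by an index search. A minimal sketch; `obs_source_indices` is a hypothetical helper and the feature names are illustrative:

```python
def obs_source_indices(features, obs_features, suffix="_obs"):
    """Index in `features` of the source feature for each `<feature>_obs` name."""
    indices = []
    for obs in obs_features:
        if not obs.endswith(suffix):
            raise ValueError(f"{obs!r} does not end with {suffix!r}")
        source = obs[: -len(suffix)]  # "temperature_obs" -> "temperature"
        indices.append(features.index(source))  # ValueError if source is missing
    return indices


features = ["u_10m", "v_10m", "temperature"]
print(obs_source_indices(features, ["temperature_obs", "u_10m_obs"]))  # [2, 0]
```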
- property offshore_obs_frac
Fraction of offshore observations to include in each batch when using proxy observations. This can be a single float or a dictionary with keys ‘spatial’ and ‘temporal’ to specify the fraction for each domain. If a dictionary is provided, the actual fraction for each batch will be sampled uniformly between the specified spatial and temporal fractions.
- property onshore_obs_frac
Fraction of onshore observations to include in each batch when using proxy observations. This can be a single float or a dictionary with keys ‘spatial’ and ‘temporal’ to specify the fraction for each domain. If a dictionary is provided, the actual fraction for each batch will be sampled uniformly between the specified spatial and temporal fractions.
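The per-batch fraction rule shared by onshore_obs_frac and offshore_obs_frac can be sketched as follows. `sample_obs_frac` is hypothetical, and sorting the two dictionary values so that either one may be the larger is an assumption:

```python
import random


def sample_obs_frac(frac, rng=random):
    """Observation fraction for one batch from a float or a spatial/temporal dict."""
    if isinstance(frac, dict):
        # uniform draw between the 'spatial' and 'temporal' values
        lo, hi = sorted((frac["spatial"], frac["temporal"]))
        return rng.uniform(lo, hi)
    return float(frac)  # a single float is used for every batch


rng = random.Random(42)
print(sample_obs_frac(0.05))  # 0.05
print(sample_obs_frac({"spatial": 0.01, "temporal": 0.1}, rng))  # a value in [0.01, 0.1]
```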
- property perturbation_scale
Scale of the perturbation to add to the proxy observations when using proxy observations. This specifies the multiplier of the noise sampled from (-standard deviation, standard deviation).
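The perturbation amounts to scaled uniform noise bounded by the per-feature standard deviation of the batch. A sketch for one feature's batch values, assuming the population standard deviation (`statistics.pstdev`); `perturb` is a hypothetical helper:

```python
import random
import statistics


def perturb(values, perturbation_scale, rng=random):
    """Add perturbation_scale * U(-std, std) noise to one feature's batch values."""
    # std computed over the batch for this single feature
    std = statistics.pstdev(values)
    return [v + perturbation_scale * rng.uniform(-std, std) for v in values]


batch = [1.0, 2.0, 3.0, 4.0]
noisy = perturb(batch, perturbation_scale=0.1, rng=random.Random(0))
```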
- post_init_log(args_dict=None)
Log additional arguments after initialization.
- preflight()
Perform shape and feature checks.
- property shape
Get shape of underlying data.
- property use_proxy_obs
Whether to use proxy observations. When True, proxy observation features are generated by masking the corresponding gridded ground truth data and are appended to the samples. The obs features are specified by the obs_features argument and should have a corresponding source feature in the data features that is used for sampling. For example, an obs feature named temperature_obs would be generated from the gridded ground truth feature named temperature.
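Proxy observation generation can be sketched as copying one gridded ground truth feature and replacing all but a random fraction of its points with NaN. This simplification ignores the onshore/offshore split and any perturbation; `make_proxy_obs` is hypothetical and flat lists stand in for gridded arrays:

```python
import math
import random


def make_proxy_obs(truth, obs_frac, rng=random):
    """Copy `truth`, keeping a random `obs_frac` of points and setting the rest to NaN."""
    n_obs = round(obs_frac * len(truth))
    observed = set(rng.sample(range(len(truth)), n_obs))
    # unobserved points become NaN, matching the obs-feature convention
    return [v if i in observed else math.nan for i, v in enumerate(truth)]


temperature = [float(i) for i in range(100)]
temperature_obs = make_proxy_obs(temperature, obs_frac=0.1, rng=random.Random(0))
```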
- wrap(data)
Return a Sup3rDataset object or tuple of such. This is a tuple when the .data attribute belongs to a Collection object like BatchHandler. Otherwise this is a Sup3rDataset object, which is either a wrapped 3-tuple, 2-tuple, or 1-tuple (e.g. len(data) == 3, len(data) == 2 or len(data) == 1). This is a 3-tuple when .data belongs to a container object like DualSamplerWithObs, a 2-tuple when .data belongs to a dual container object like DualSampler, and a 1-tuple otherwise.