sup3r.preprocessing.batch_handlers.factory.BatchHandlerMom1#
- class BatchHandlerMom1(train_containers, *, val_containers=None, sample_shape=None, batch_size=16, n_batches=64, s_enhance=1, t_enhance=1, means=None, stds=None, queue_cap=None, transform_kwargs=None, mode='lazy', feature_sets=None, proxy_obs_kwargs=None, time_enhance_mode='constant', lower_models=None, s_padding=0, t_padding=0, end_t_padding=False, verbose=False, **kwargs)#
Bases: QueueMom1
BatchHandler object built from two lists of Container objects, one with training data and one with validation data. These lists will be used to initialize lists of Sampler objects that will then be used to build batches at run time.
Notes
These lists of containers can contain data from the same underlying data source, e.g. CONUS WTK (initialize train / val containers with different time periods and / or regions), or they can be used to sample from completely different data sources (e.g. train on CONUS WTK while validating on Canada WTK).
See also
Sampler, AbstractBatchQueue, StatsCollection
- Parameters:
train_containers (list[Container]) – List of objects with a .data attribute, which will be used to initialize Sampler objects and then used to initialize a batch queue of training data. The data can be a Sup3rX or Sup3rDataset object.
val_containers (list[Container]) – List of objects with a .data attribute, which will be used to initialize Sampler objects and then used to initialize a batch queue of validation data. The data can be a Sup3rX or a Sup3rDataset object.
batch_size (int) – Number of observations / samples in a batch
n_batches (int) – Number of batches in an epoch, this sets the iteration limit for this object.
s_enhance (int) – Integer factor by which the spatial axes are to be enhanced.
t_enhance (int) – Integer factor by which the temporal axis is to be enhanced.
means (str | dict | None) – Usually a file path for loading / saving results, or None for just calculating stats and not saving. Can also be a dict.
stds (str | dict | None) – Usually a file path for loading / saving results, or None for just calculating stats and not saving. Can also be a dict.
queue_cap (int) – Maximum number of batches the batch queue can store.
transform_kwargs (Union[dict, None]) – Dictionary of kwargs to be passed to self.transform. This method performs smoothing / coarsening.
mode (str) – Loading mode. Default is ‘lazy’, which only loads data into memory as batches are queued. ‘eager’ will load all data into memory right away.
feature_sets (Optional[dict]) – Optional dictionary describing how the full set of features is split between lr_features, hr_exo_features, and hr_out_features.
- lr_features : list | tuple
List of feature names or patterns* to use as low-resolution model inputs. If no entry is provided then all available features from the data will be used.
- hr_out_features : list | tuple
List of feature names or patterns* that should be output by the generative model and available as ground truth targets. If no entry is provided then all features in lr_features will be used.
- hr_exo_features : list | tuple
List of feature names or patterns* that should be available as high-resolution model inputs (like topography or observations) or for bespoke loss functions. Features used as inputs are injected into the model mid-network to condition output on high-resolution information. The model configuration should have the appropriate layers to use these features, e.g. Sup3rConcat for topography injection, Sup3rObsModel or Sup3rCrossAttention for obs injection. If no entry is provided then hr_exo_features will be empty.
*To include sparse features as inputs or targets the features must have an "_obs" suffix.
kwargs (dict) – Keyword arguments for parent class
sample_shape (tuple) – Size of arrays to sample from the contained data.
proxy_obs_kwargs (dict | None) – Optional dictionary of keyword arguments to pass to the proxy observation generator. This is only used when training with proxy observations. Keys can include onshore_obs_frac, offshore_obs_frac, and perturbation_scale.
- perturbation_scale : float
Scale of the perturbation to add to the proxy observations. This specifies the multiplier of the noise sampled from (-standard deviation, standard deviation). The standard deviation is calculated per feature over each batch.
- onshore_obs_frac : float | dict
Fraction of onshore observations to include in each batch when using proxy observations. This can be a single float or a dictionary with keys 'spatial' and 'temporal' to specify the fraction for each domain. If a dictionary is provided, the actual fraction for each batch will be sampled uniformly between the specified spatial and temporal fractions.
- offshore_obs_frac : float | dict
Fraction of offshore observations to include in each batch when using proxy observations. This can be a single float or a dictionary with keys 'spatial' and 'temporal' to specify the fraction for each domain. If a dictionary is provided, the actual fraction for each batch will be sampled uniformly between the specified spatial and temporal fractions.
time_enhance_mode (str) – [constant, linear] Method used for temporal enhancement when constructing the subfilter. At every temporal location, the low-res temporal data is subtracted from the predicted high-res temporal data. 'constant' assumes that the low-res temporal data is constant between landmarks. 'linear' linearly interpolates between landmarks to generate the low-res data to remove from the high-res.
lower_models (dict[int, Sup3rCondMom] | None) – Dictionary of models that predict lower moments. For example, if this queue is part of a handler to estimate the 3rd moment lower_models could include models that estimate the 1st and 2nd moments. These lower moments can be required in higher order moment calculations.
s_padding (int | None) – Width of spatial padding to predict only middle part. If None, no padding is used
t_padding (int | None) – Width of temporal padding to predict only middle part. If None, no padding is used
end_t_padding (bool) – Zero pad the end of temporal space. Ensures that loss is calculated only if the snapshot is surrounded by temporal landmarks. False by default.
verbose (bool) – Whether to log timing information for batch steps.
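As a concrete illustration of the parameters above, here is a hedged sketch of the nested keyword dictionaries. The feature names and numeric values below are made-up examples, not library defaults; only the keyword names come from the signature and parameter descriptions above.

```python
# Hypothetical values -- only the key names come from the docs above.
feature_sets = {
    'lr_features': ['u_100m', 'v_100m'],      # low-res model inputs
    'hr_out_features': ['u_100m', 'v_100m'],  # generated / ground-truth outputs
    'hr_exo_features': ['topography'],        # high-res exogenous inputs
}
proxy_obs_kwargs = {
    # dict form: per-batch fraction sampled uniformly between the two values
    'onshore_obs_frac': {'spatial': 0.05, 'temporal': 0.5},
    'offshore_obs_frac': 0.01,                # single-float form
    'perturbation_scale': 0.1,
}
handler_kwargs = dict(
    sample_shape=(20, 20, 24),
    batch_size=16,
    n_batches=64,
    s_enhance=2,
    t_enhance=4,
    feature_sets=feature_sets,
    proxy_obs_kwargs=proxy_obs_kwargs,
    mode='lazy',
)
```

These dictionaries would be passed as the corresponding keyword arguments when constructing the handler.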
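The difference between the two time_enhance_mode options can be sketched with NumPy: 'constant' holds each low-res value fixed between temporal landmarks, while 'linear' interpolates between them. This is an illustrative stand-in with made-up values, not the library's implementation.

```python
import numpy as np

lr = np.array([0.0, 10.0])   # low-res values at two temporal landmarks
t_enhance = 4

# 'constant': assume the low-res value is constant between landmarks
constant = np.repeat(lr, t_enhance)            # step function

# 'linear': linearly interpolate between landmarks
t_hr = np.arange(len(lr) * t_enhance)          # high-res time indices
t_landmarks = np.arange(len(lr)) * t_enhance   # landmark positions [0, 4]
linear = np.interp(t_hr, t_landmarks, lr)
```

Either reconstruction is then subtracted from the high-res data to form the subfilter signal.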
Methods
check_enhancement_factors() – Make sure the enhancement factors evenly divide the sample_shape.
check_features() – Make sure all samplers have the same sets of features.
check_shared_attr(attr) – Check if all containers have the same value for attr.
enqueue_batches() – Callback function for queue thread.
get_batch() – Get batch from queue or directly from a Sampler through sample_batch.
get_container_index() – Get random container index based on weights.
get_queue() – Return FIFO queue for storing batches.
get_random_container() – Get random container based on container weights.
init_samplers(train_containers, ...) – Initialize samplers from given data containers.
log_queue_info() – Log info about queue size.
make_mask(high_res) – Make mask for output.
make_output(samples) – For the 1st moment the output is simply the high_res.
post_init_log([args_dict]) – Log additional arguments after initialization.
post_proc(samples) – Return normalized collection of samples / observations along with mask and target output for conditional moment estimation.
preflight() – Run checks before kicking off the queue.
sample_batch() – Get random sampler from collection and return a batch of samples from that sampler.
sample_batches(n_batches) – Sample given number of batches either in serial or with thread pool.
start() – Start the val data batch queue in addition to the train batch queue.
stop() – Stop the val data batch queue in addition to the train batch queue.
transform(samples[, smoothing, ...]) – Coarsen high res data to get corresponding low res batch.
wrap(data) – Return a Sup3rDataset object or tuple of such.
Attributes
timer
BATCH_MEMBERS
container_weights – Get weights used to sample from different containers based on relative sizes.
data – Return underlying data.
features – Get all features contained in data.
hr_shape – Shape of high resolution sample in a low-res / high-res pair.
lr_shape – Shape of low resolution sample in a low-res / high-res pair.
queue_futures – Get number of scheduled futures that will eventually add batches to the queue.
queue_len – Get number of batches in the queue.
queue_shape – Shape of objects stored in the queue.
queue_thread – Get new queue thread.
running – Boolean to check whether to keep enqueueing batches.
shape – Get shape of underlying data.
shapes – Shapes of batches returned by __next__.
- check_enhancement_factors()#
Make sure the enhancement factors evenly divide the sample_shape.
- check_features()#
Make sure all samplers have the same sets of features.
- check_shared_attr(attr)#
Check if all containers have the same value for attr. If they do the collection effectively inherits those attributes.
- property container_weights#
Get weights used to sample from different containers based on relative sizes
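The weighting described here can be sketched in a few lines of NumPy: weights proportional to relative container sizes, then used as sampling probabilities. The sizes below are made-up values, and this is an illustration of the idea rather than the library's internal code.

```python
import numpy as np

# Two hypothetical containers with different numbers of samples
sizes = np.array([1000, 3000])
weights = sizes / sizes.sum()          # proportional to relative size
idx = np.random.choice(len(sizes), p=weights)  # weighted container pick
```

The larger container is sampled three times as often as the smaller one.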
- property data#
Return underlying data.
- enqueue_batches() → None#
Callback function for queue thread. While training, the queue is checked for empty spots and filled. In the training thread, batches are removed from the queue.
- property features#
Get all features contained in data.
- get_container_index()#
Get random container index based on weights
- get_queue()#
Return FIFO queue for storing batches.
- get_random_container()#
Get random container based on container weights
- property hr_shape#
Shape of high resolution sample in a low-res / high-res pair. (e.g. (spatial_1, spatial_2, temporal, features))
- init_samplers(train_containers, val_containers, sample_shape, feature_sets, batch_size, mode, sampler_kwargs)#
Initialize samplers from given data containers.
- log_queue_info()#
Log info about queue size.
- property lr_shape#
Shape of low resolution sample in a low-res / high-res pair. (e.g. (spatial_1, spatial_2, temporal, features))
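The relation between these two shapes is simple division by the enhancement factors, which check_enhancement_factors() requires to be exact. A sketch with assumed values (the shapes and feature counts are hypothetical):

```python
s_enhance, t_enhance = 2, 4
sample_shape = (20, 20, 24)           # (spatial_1, spatial_2, temporal)
n_lr_features, n_hr_features = 2, 2   # hypothetical feature counts

hr_shape = (*sample_shape, n_hr_features)
lr_shape = (sample_shape[0] // s_enhance,
            sample_shape[1] // s_enhance,
            sample_shape[2] // t_enhance,
            n_lr_features)

# enhancement factors must evenly divide the sample_shape
assert sample_shape[0] % s_enhance == 0
assert sample_shape[2] % t_enhance == 0
```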
- make_mask(high_res)#
Make mask for output. This is used to ensure consistency when training conditional moments.
Note
Consider the case of learning E(HR|LR) where HR is the high_res and LR is the low_res. In theory, the conditional moment estimation works if the full LR is passed as input and predicts the full HR. In practice, only the LR data that overlaps and surrounds the HR data is useful, ie E(HR|LR) = E(HR|LR_nei) where LR_nei is the LR data that surrounds the HR data. Physically, this is equivalent to saying that data far away from a region of interest does not matter. This allows learning the conditional moments on spatial and temporal chunks only if one restricts the high_res output as being overlapped and surrounded by the input low_res. The role of the mask is to ensure that the input low_res always surrounds the output high_res.
- Parameters:
high_res (Union[np.ndarray, da.core.Array]) – 4D | 5D array: (batch_size, spatial_1, spatial_2, features) or (batch_size, spatial_1, spatial_2, temporal, features)
- Returns:
mask (Union[np.ndarray, da.core.Array]) – 4D | 5D array: (batch_size, spatial_1, spatial_2, features) or (batch_size, spatial_1, spatial_2, temporal, features)
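The padding idea behind the mask can be sketched as follows: zero out the borders so loss is only computed on the interior region that is fully surrounded by low-res input. This is a simplified stand-in, assuming a 5D batch and nonzero s_padding / t_padding; the actual make_mask implementation may differ.

```python
import numpy as np

def make_mask_sketch(high_res, s_padding=2, t_padding=1):
    """Illustrative mask: 1 in the interior, 0 in padded borders."""
    mask = np.zeros_like(high_res)
    mask[:, s_padding:-s_padding, s_padding:-s_padding,
         t_padding:-t_padding, :] = 1
    return mask

batch = np.ones((4, 10, 10, 8, 2))    # (batch, s1, s2, t, features)
mask = make_mask_sketch(batch)
```

Multiplying the model output by such a mask restricts the conditional-moment loss to the surrounded interior.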
- make_output(samples)#
For the 1st moment the output is simply the high_res
- post_init_log(args_dict=None)#
Log additional arguments after initialization.
- post_proc(samples)#
Returns normalized collection of samples / observations along with mask and target output for conditional moment estimation. Performs coarsening on high-res data if the Collection consists of Sampler objects and not DualSampler objects.
- Returns:
DsetTuple – Namedtuple-like object with low_res, high_res, mask, and output attributes
- preflight()#
Run checks before kicking off the queue.
- property queue_futures#
Get number of scheduled futures that will eventually add batches to the queue.
- property queue_len#
Get number of batches in the queue.
- property queue_shape#
Shape of objects stored in the queue.
- property queue_thread#
Get new queue thread.
- property running#
Boolean to check whether to keep enqueueing batches.
- sample_batch()#
Get random sampler from collection and return a batch of samples from that sampler.
Notes
These samples are wrapped in an np.asarray call, so they have been loaded into memory.
- sample_batches(n_batches) → None#
Sample given number of batches either in serial or with thread pool.
- property shape#
Get shape of underlying data.
- property shapes#
Shapes of batches returned by
__next__
- start()#
Start the val data batch queue in addition to the train batch queue.
- stop()#
Stop the val data batch queue in addition to the train batch queue.
- transform(samples, smoothing=None, smoothing_ignore=None, temporal_coarsening_method='subsample')#
Coarsen high res data to get corresponding low res batch.
- Parameters:
samples (Union[np.ndarray, da.core.Array]) – High resolution batch of samples. 4D | 5D array: (batch_size, spatial_1, spatial_2, features) or (batch_size, spatial_1, spatial_2, temporal, features)
smoothing (float | None) – Standard deviation to use for gaussian filtering of the coarse data. This can be tuned by matching the kinetic energy of a low resolution simulation with the kinetic energy of a coarsened and smoothed high resolution simulation. If None no smoothing is performed.
smoothing_ignore (list | None) – List of features to ignore for the smoothing filter. None will smooth all features if smoothing kwarg is not None
temporal_coarsening_method (str) – Method to use for temporal coarsening. Can be subsample, average, min, max, or total
- Returns:
low_res (Union[np.ndarray, da.core.Array]) – 4D | 5D array: (batch_size, spatial_1, spatial_2, features) or (batch_size, spatial_1, spatial_2, temporal, features)
high_res (Union[np.ndarray, da.core.Array]) – 4D | 5D array: (batch_size, spatial_1, spatial_2, features) or (batch_size, spatial_1, spatial_2, temporal, features)
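Two of the temporal_coarsening_method options can be sketched with NumPy on a toy 5D batch ('subsample' keeps every t_enhance-th step, 'average' takes the mean over each window). The array values are made-up, and this is an illustration of the coarsening semantics, not the library's transform code.

```python
import numpy as np

hr = np.arange(8.0).reshape(1, 1, 1, 8, 1)   # (batch, s1, s2, t, features)
t_enhance = 4

# 'subsample': keep every t_enhance-th time step
subsample = hr[..., ::t_enhance, :]

# 'average': mean over each non-overlapping window of t_enhance steps
average = hr.reshape(1, 1, 1, 2, t_enhance, 1).mean(axis=4)
```

The 'min', 'max', and 'total' options would replace the mean with the corresponding window reduction.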
- wrap(data)#
Return a Sup3rDataset object or tuple of such. This is a tuple when the .data attribute belongs to a Collection object like BatchHandler. Otherwise this is a Sup3rDataset object, which is either a wrapped 3-tuple, 2-tuple, or 1-tuple (e.g. len(data) == 3, len(data) == 2 or len(data) == 1). This is a 3-tuple when .data belongs to a container object like DualSamplerWithObs, a 2-tuple when .data belongs to a dual container object like DualSampler, and a 1-tuple otherwise.
Sup3rDatasetobject or tuple of such. This is a tuple when the.dataattribute belongs to aCollectionobject likeBatchHandler. Otherwise this isSup3rDatasetobject, which is either a wrapped 3-tuple, 2-tuple, or 1-tuple (e.g.len(data) == 3,len(data) == 2orlen(data) == 1). This is a 3-tuple when.databelongs to a container object likeDualSamplerWithObs, a 2-tuple when.databelongs to a dual container object likeDualSampler, and a 1-tuple otherwise.