sup3r.preprocessing.batch_handlers.factory.DualBatchHandler#

class DualBatchHandler(train_containers, *, val_containers=None, sample_shape=None, batch_size=16, n_batches=64, s_enhance=1, t_enhance=1, means=None, stds=None, queue_cap=None, transform_kwargs=None, mode='lazy', feature_sets=None, proxy_obs_kwargs=None, verbose=False, **kwargs)#

Bases: DualBatchQueue

BatchHandler object built from two lists of Container objects, one with training data and one with validation data. These lists are used to initialize lists of Sampler objects, which then build batches at run time.

Notes

These lists of containers can hold data from the same underlying data source (e.g. CONUS WTK), for example by initializing the train and val containers with different time periods and / or regions, or they can be used to sample from completely different data sources (e.g. train on CONUS WTK while validating on Canada WTK).

See also

Sampler, AbstractBatchQueue, StatsCollection

Parameters:
  • train_containers (list[Container]) – List of objects with a .data attribute, which will be used to initialize Sampler objects and then used to initialize a batch queue of training data. The data can be a Sup3rX or Sup3rDataset object.

  • val_containers (list[Container]) – List of objects with a .data attribute, which will be used to initialize Sampler objects and then used to initialize a batch queue of validation data. The data can be a Sup3rX or a Sup3rDataset object.

  • batch_size (int) – Number of observations / samples in a batch

  • n_batches (int) – Number of batches in an epoch. This sets the iteration limit for this object.

  • s_enhance (int) – Integer factor by which the spatial axes are to be enhanced.

  • t_enhance (int) – Integer factor by which the temporal axis is to be enhanced.

  • means (str | dict | None) – Usually a file path for loading / saving results, or None for just calculating stats and not saving. Can also be a dict.

  • stds (str | dict | None) – Usually a file path for loading / saving results, or None for just calculating stats and not saving. Can also be a dict.

  • queue_cap (int) – Maximum number of batches the batch queue can store.

  • transform_kwargs (Union[dict, None]) – Dictionary of kwargs to be passed to self.transform. This method performs smoothing / coarsening.

  • mode (str) – Loading mode. Default is ‘lazy’, which only loads data into memory as batches are queued. ‘eager’ will load all data into memory right away.

  • feature_sets (Optional[dict]) – Optional dictionary describing how the full set of features is split between lr_features, hr_exo_features, and hr_out_features.

    lr_features (list | tuple)

    List of feature names or patterns* to use as low-resolution model inputs. If no entry is provided then all available features from the data will be used.

    hr_out_features (list | tuple)

    List of feature names or patterns* that should be output by the generative model and available as ground truth targets. If no entry is provided then all features in the high-res data will be used.

    hr_exo_features (list | tuple)

    List of feature names or patterns* that should be available as high-resolution model inputs (like topography or observations) or for bespoke loss functions. Features used for input are injected into the model mid-network to condition output on high-resolution information. The model configuration should have the appropriate layers to use these features, e.g. Sup3rConcat for topography injection, or Sup3rObsModel or Sup3rCrossAttention for obs injection. If no entry is provided then hr_exo_features will be empty.

    *To include sparse features as inputs or targets, the features must have an “_obs” suffix.

  • kwargs (dict) – Additional keyword arguments for BatchQueue and / or Samplers. These vary depending on the type of BatchQueue / Sampler given to the Factory. For example, to build a BatchHandlerDC object (data-centric batch handler) we use a queue and sampler that take spatial and temporal weight / bin arguments, which determine how to weigh spatiotemporal regions when sampling. Using ConditionalBatchQueue adds arguments that control how moments are computed from batches and how batch data is padded to enable these calculations.

  • sample_shape (tuple) – Size of arrays to sample from the high-res data. The sample shape for the low-res sampler will be determined from the enhancement factors.

  • proxy_obs_kwargs (dict | None) – Optional dictionary of keyword arguments to pass to the proxy observation generator. This is only used when training with proxy observations. Keys can include onshore_obs_frac, offshore_obs_frac, and perturbation_scale.

    perturbation_scale (float)

    Scale of the perturbation added to the proxy observations. This specifies the multiplier of the noise sampled from (-standard deviation, standard deviation). The standard deviation is calculated per feature over each batch.

    onshore_obs_frac (float | dict)

    Fraction of onshore observations to include in each batch when using proxy observations. This can be a single float or a dictionary with keys ‘spatial’ and ‘temporal’ to specify the fraction for each domain. If a dictionary is provided, the actual fraction for each batch will be sampled uniformly between the specified spatial and temporal fractions.

    offshore_obs_frac (float | dict)

    Fraction of offshore observations to include in each batch when using proxy observations. This can be a single float or a dictionary with keys ‘spatial’ and ‘temporal’ to specify the fraction for each domain. If a dictionary is provided, the actual fraction for each batch will be sampled uniformly between the specified spatial and temporal fractions.

  • verbose (bool) – Whether to log timing information for batch steps.
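
As an illustration of the onshore_obs_frac / offshore_obs_frac description above, the following is a minimal sketch of how a float-or-dict fraction could be resolved for one batch. The helper resolve_obs_frac and the numeric values are hypothetical and are not part of sup3r; only the key names come from the parameter description.

```python
import numpy as np


def resolve_obs_frac(obs_frac, rng):
    """Return the observation fraction to use for a single batch.

    Hypothetical helper (not a sup3r function). A float is used as-is.
    A dict with 'spatial' and 'temporal' keys is sampled uniformly
    between the two values, as the parameter description states.
    """
    if isinstance(obs_frac, dict):
        low, high = sorted((obs_frac["spatial"], obs_frac["temporal"]))
        return float(rng.uniform(low, high))
    return float(obs_frac)


rng = np.random.default_rng(0)

# Illustrative values only; the key names follow the documented parameters.
proxy_obs_kwargs = {
    "onshore_obs_frac": {"spatial": 0.1, "temporal": 0.3},
    "offshore_obs_frac": 0.05,
    "perturbation_scale": 0.1,
}

onshore = resolve_obs_frac(proxy_obs_kwargs["onshore_obs_frac"], rng)
offshore = resolve_obs_frac(proxy_obs_kwargs["offshore_obs_frac"], rng)
```

With a dict, each call can return a different fraction, so the number of proxy observations would vary from batch to batch; with a float it stays fixed.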

Methods

check_enhancement_factors()

Make sure each DualSampler has the same enhancement factors and they match those provided to the BatchQueue.

check_features()

Make sure all samplers have the same sets of features.

check_shared_attr(attr)

Check if all containers have the same value for attr.

enqueue_batches()

Callback function for queue thread.

get_batch()

Get batch from queue or directly from a Sampler through sample_batch.

get_container_index()

Get random container index based on weights

get_queue()

Return FIFO queue for storing batches.

get_random_container()

Get random container based on container weights

init_samplers(train_containers, ...)

Initialize samplers from given data containers.

log_queue_info()

Log info about queue size.

post_init_log([args_dict])

Log additional arguments after initialization.

post_proc(samples)

Performs some post proc on dequeued samples before sending out for training.

preflight()

Run checks before kicking off the queue.

sample_batch()

Get random sampler from collection and return a batch of samples from that sampler.

sample_batches(n_batches)

Sample given number of batches either in serial or with thread pool.

start()

Start the val data batch queue in addition to the train batch queue.

stop()

Stop the val data batch queue in addition to the train batch queue.

transform(samples[, smoothing, smoothing_ignore])

Perform smoothing if requested.

wrap(data)

Return a Sup3rDataset object or tuple of such.

Attributes

timer

BATCH_MEMBERS

container_weights

Get weights used to sample from different containers based on relative sizes

data

Return underlying data.

features

Get all features contained in data.

hr_shape

Shape of high resolution sample in a low-res / high-res pair.

lr_shape

Shape of low resolution sample in a low-res / high-res pair.

queue_futures

Get number of scheduled futures that will eventually add batches to the queue.

queue_len

Get number of batches in the queue.

queue_shape

Shape of objects stored in the queue.

queue_thread

Get new queue thread.

running

Boolean to check whether to keep enqueueing batches.

shape

Get shape of underlying data.

shapes

Shapes of batches returned by __next__

SAMPLER#

alias of DualSampler

TRAIN_QUEUE#

alias of DualBatchQueue

VAL_QUEUE#

alias of DualBatchQueue

check_enhancement_factors()#

Make sure each DualSampler has the same enhancement factors and they match those provided to the BatchQueue.

check_features()#

Make sure all samplers have the same sets of features.

check_shared_attr(attr)#

Check if all containers have the same value for attr. If they do the collection effectively inherits those attributes.
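
The check described above can be sketched as follows. This is an assumption-laden illustration, not the sup3r implementation: the standalone function takes the containers as an argument, and the SimpleNamespace objects stand in for real Container instances.

```python
from types import SimpleNamespace


def check_shared_attr(containers, attr):
    """Return the value of ``attr`` shared by all containers.

    Sketch only: raises ValueError when containers disagree, so the
    caller can treat the returned value as an attribute of the whole
    collection.
    """
    values = [getattr(c, attr, None) for c in containers]
    first = values[0]
    if any(v != first for v in values[1:]):
        raise ValueError(f"Containers disagree on '{attr}': {values}")
    return first


# Stand-in containers with illustrative feature names.
matching = [
    SimpleNamespace(features=["u_10m", "v_10m"]),
    SimpleNamespace(features=["u_10m", "v_10m"]),
]
mismatched = [
    SimpleNamespace(features=["u_10m", "v_10m"]),
    SimpleNamespace(features=["u_10m"]),
]

shared = check_shared_attr(matching, "features")
```

Calling check_shared_attr(mismatched, "features") raises ValueError in this sketch.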

property container_weights#

Get weights used to sample from different containers based on relative sizes
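
A minimal sketch of size-based weights, assuming that "relative sizes" means the number of elements in each container divided by the total. The shapes are illustrative and the exact size measure used by sup3r is not stated in this page.

```python
import numpy as np

# Illustrative (spatial_1, spatial_2, temporal) shapes for two containers.
container_shapes = [(100, 100, 8760), (50, 50, 8760)]

# Size of each container, taken here as the product of its dimensions.
sizes = np.array([np.prod(shape) for shape in container_shapes], dtype=float)

# Normalize so the weights sum to one and can be used as probabilities.
container_weights = sizes / sizes.sum()
```

Under this assumption the larger container is sampled four times as often as the smaller one.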

property data#

Return underlying data.

Returns:

Sup3rDataset

See also

wrap()

enqueue_batches() → None#

Callback function for queue thread. While training, the queue is checked for empty spots and filled. In the training thread, batches are removed from the queue.
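
The producer / consumer pattern described above can be sketched with the standard library. This is a generic illustration under stated assumptions, not the sup3r code: sample_batch is a stand-in that returns integers, maxsize plays the role of queue_cap, and the Event plays the role of the running flag.

```python
import itertools
import queue
import threading
import time

batch_queue = queue.Queue(maxsize=4)  # FIFO queue with a fixed capacity
running = threading.Event()
running.set()

_counter = itertools.count()


def sample_batch():
    """Stand-in for building a batch; returns an increasing integer."""
    return next(_counter)


def enqueue_batches():
    """Fill empty spots in the queue for as long as the flag is set."""
    while running.is_set():
        if batch_queue.full():
            time.sleep(0.01)  # no empty spot, check again shortly
            continue
        batch_queue.put(sample_batch())


thread = threading.Thread(target=enqueue_batches, daemon=True)
thread.start()

# The "training" side removes batches from the queue in FIFO order.
consumed = [batch_queue.get(timeout=1.0) for _ in range(8)]

# Clearing the flag lets the producer loop exit.
running.clear()
thread.join(timeout=1.0)
```

Because there is a single producer, the full() check followed by put() never blocks in this sketch.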

property features#

Get all features contained in data.

get_batch() → DsetTuple#

Get batch from queue or directly from a Sampler through sample_batch.

get_container_index()#

Get random container index based on weights

get_queue()#

Return FIFO queue for storing batches.

get_random_container()#

Get random container based on container weights
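
A sketch of weighted selection in the spirit of get_container_index and get_random_container. The standalone functions, the string containers, and the weights are hypothetical; the actual methods take no arguments and read the containers and weights from the object.

```python
import numpy as np

rng = np.random.default_rng(42)

# Stand-in containers and illustrative weights that sum to one.
containers = ["large_region", "small_region"]
weights = [0.8, 0.2]


def get_container_index(weights, rng):
    """Draw one container index with probability given by the weights."""
    return int(rng.choice(len(weights), p=weights))


def get_random_container(containers, weights, rng):
    """Return the container at a randomly drawn index."""
    return containers[get_container_index(weights, rng)]


draws = [get_container_index(weights, rng) for _ in range(1000)]
frac_large = draws.count(0) / len(draws)
picked = get_random_container(containers, weights, rng)
```

Over many draws, frac_large approaches the first weight, so larger containers contribute proportionally more batches.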

property hr_shape#

Shape of the high-resolution sample in a low-res / high-res pair, e.g. (spatial_1, spatial_2, temporal, features).

init_samplers(train_containers, val_containers, sample_shape, feature_sets, batch_size, mode, sampler_kwargs)#

Initialize samplers from given data containers.

log_queue_info()#

Log info about queue size.

property lr_shape#

Shape of the low-resolution sample in a low-res / high-res pair, e.g. (spatial_1, spatial_2, temporal, features).
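
The sample_shape parameter states that the low-res sample shape is determined from the enhancement factors. A minimal sketch of that relationship follows, assuming integer division of each high-res axis by its factor. The helper name, the divisibility check, and the feature lists are assumptions made for this illustration.

```python
def lr_sample_shape(hr_sample_shape, s_enhance, t_enhance):
    """Derive the low-res sample shape from the high-res sample shape.

    Hypothetical helper: both spatial axes are divided by s_enhance and
    the temporal axis by t_enhance. Raising on non-divisible shapes is
    a choice made for this sketch.
    """
    rows, cols, steps = hr_sample_shape
    if rows % s_enhance or cols % s_enhance or steps % t_enhance:
        raise ValueError("sample_shape must be divisible by the factors")
    return (rows // s_enhance, cols // s_enhance, steps // t_enhance)


hr_sample = (20, 20, 24)  # (spatial_1, spatial_2, temporal)
lr_sample = lr_sample_shape(hr_sample, s_enhance=4, t_enhance=6)

# Illustrative feature lists; the trailing axis counts features.
lr_features = ["u_10m", "v_10m"]
hr_out_features = ["u_10m", "v_10m"]
lr_shape = (*lr_sample, len(lr_features))
hr_shape = (*hr_sample, len(hr_out_features))
```

Here a (20, 20, 24) high-res sample pairs with a (5, 5, 4) low-res sample.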

post_init_log(args_dict=None)#

Log additional arguments after initialization.

post_proc(samples) → DsetTuple#

Performs post-processing on dequeued samples before sending them out for training. Post-processing can include coarsening of high-res data (if the Collection consists of Sampler objects and not DualSampler objects), smoothing, etc.

Returns:

Batch (DsetTuple) – namedtuple-like object with low_res and high_res attributes. Could also include obs member.
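
To illustrate the coarsening case mentioned above (the non-dual case, where low-res data is derived from high-res samples), here is a sketch that averages spatial blocks and returns a namedtuple with low_res and high_res members. The Batch type, the spatial_coarsen helper, and the (n_obs, spatial_1, spatial_2, temporal, features) layout are assumptions for this example, not the sup3r DsetTuple or its coarsening routine.

```python
from collections import namedtuple

import numpy as np

# Stand-in for a namedtuple-like batch with low_res and high_res members.
Batch = namedtuple("Batch", ["low_res", "high_res"])


def spatial_coarsen(high_res, s_enhance):
    """Average non-overlapping s_enhance x s_enhance spatial blocks.

    Assumes a 5D array laid out as
    (n_obs, spatial_1, spatial_2, temporal, features).
    """
    n_obs, rows, cols, steps, feats = high_res.shape
    blocks = high_res.reshape(
        n_obs,
        rows // s_enhance, s_enhance,
        cols // s_enhance, s_enhance,
        steps, feats,
    )
    # Axes 2 and 4 index positions inside each spatial block.
    return blocks.mean(axis=(2, 4))


high_res = np.arange(2 * 4 * 4 * 3 * 1, dtype=float).reshape(2, 4, 4, 3, 1)
batch = Batch(low_res=spatial_coarsen(high_res, s_enhance=2), high_res=high_res)
```

With DualSampler objects the low-res and high-res data come from separate sources, so this coarsening step would not apply.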

preflight()#

Run checks before kicking off the queue.

property queue_futures#

Get number of scheduled futures that will eventually add batches to the queue.

property queue_len#

Get number of batches in the queue.

property queue_shape#

Shape of objects stored in the queue. Optionally includes the shape of observation data, which would be included in an extra content loss term.

property queue_thread#

Get new queue thread.

property running#

Boolean to check whether to keep enqueueing batches.

sample_batch()#

Get random sampler from collection and return a batch of samples from that sampler.

Notes

These samples are wrapped in an np.asarray call, so they have been loaded into memory.

sample_batches(n_batches) → None#

Sample given number of batches either in serial or with thread pool.
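
The serial / thread pool choice can be sketched with concurrent.futures. This is a generic illustration: the max_workers argument, the seeded sample_batch stand-in, and the returned list are assumptions. The documented method returns None, so it presumably hands batches to the queue instead of returning them.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def sample_batch(seed):
    """Stand-in for drawing one batch; seeded so results are repeatable."""
    rng = np.random.default_rng(seed)
    return rng.random((4, 5, 5, 4, 2))


def sample_batches(n_batches, max_workers=None):
    """Sample batches serially, or with a thread pool when workers > 1."""
    if max_workers is None or max_workers <= 1:
        return [sample_batch(i) for i in range(n_batches)]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # map preserves input order, so results line up with the seeds.
        return list(pool.map(sample_batch, range(n_batches)))


serial = sample_batches(4)
threaded = sample_batches(4, max_workers=2)
```

Both paths produce the same batches here because each call is seeded independently; a thread pool mainly helps when sampling is dominated by I/O.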

property shape#

Get shape of underlying data.

property shapes#

Shapes of batches returned by __next__

start()#

Start the val data batch queue in addition to the train batch queue.

stop()#

Stop the val data batch queue in addition to the train batch queue.

transform(samples, smoothing=None, smoothing_ignore=None)#

Perform smoothing if requested.

Note

Unlike SingleBatchQueue, this does not include temporal or spatial coarsening.

wrap(data)#

Return a Sup3rDataset object or tuple of such. This is a tuple when the .data attribute belongs to a Collection object like BatchHandler. Otherwise this is a Sup3rDataset object, which is a wrapped 3-tuple, 2-tuple, or 1-tuple (i.e. len(data) == 3, len(data) == 2, or len(data) == 1). This is a 3-tuple when .data belongs to a container object like DualSamplerWithObs, a 2-tuple when .data belongs to a dual container object like DualSampler, and a 1-tuple otherwise.