rbnet.util.SequenceDataModule

class rbnet.util.SequenceDataModule(sequences, val_split=0.2, test_split=0.1)[source]

Bases: LightningDataModule

Attributes:
prepare_data_per_node:

If True, each process with LOCAL_RANK=0 will call prepare_data(). Otherwise only the process with NODE_RANK=0, LOCAL_RANK=0 will prepare data.

allow_zero_length_dataloader_with_multiple_devices:

If True, a dataloader with zero length within a local rank is allowed. Default value is False.
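
Example (a minimal usage sketch; the toy model, the tensor layout of sequences, and the module's internal batching are assumptions for illustration, not part of the documented API):

import torch
from torch import nn
import pytorch_lightning as pl
from rbnet.util import SequenceDataModule

class TinyModel(pl.LightningModule):
    # placeholder model standing in for an actual rbnet model
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        return self.l1(batch).mean()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

# toy data: ten sequences of 16 frames with 8 features each
sequences = [torch.randn(16, 8) for _ in range(10)]

# hold out 20% of the sequences for validation and 10% for testing
dm = SequenceDataModule(sequences, val_split=0.2, test_split=0.1)

trainer = pl.Trainer(max_epochs=10)
trainer.fit(TinyModel(), datamodule=dm)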

Public Data Attributes:

Inherited from LightningDataModule

name

CHECKPOINT_HYPER_PARAMS_KEY

CHECKPOINT_HYPER_PARAMS_NAME

CHECKPOINT_HYPER_PARAMS_TYPE

Inherited from HyperparametersMixin

hparams

The collection of hyperparameters saved with save_hyperparameters().

hparams_initial

The collection of hyperparameters saved with save_hyperparameters().

Public Methods:

__init__(sequences[, val_split, test_split])

setup([stage])

Called at the beginning of fit (train + validate), validate, test, or predict.

train_dataloader()

An iterable or collection of iterables specifying training samples.

val_dataloader()

An iterable or collection of iterables specifying validation samples.

test_dataloader()

An iterable or collection of iterables specifying test samples.

Inherited from LightningDataModule

__init__()

from_datasets([train_dataset, val_dataset, ...])

Create an instance from torch.utils.data.Dataset.

state_dict()

Called when saving a checkpoint, implement to generate and save datamodule state.

load_state_dict(state_dict)

Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict.

on_exception(exception)

Called when the trainer execution is interrupted by an exception.

load_from_checkpoint(checkpoint_path[, ...])

Primary way of loading a datamodule from a checkpoint.

Inherited from DataHooks

__init__()

prepare_data()

Use this to download and prepare data.

setup(stage)

Called at the beginning of fit (train + validate), validate, test, or predict.

teardown(stage)

Called at the end of fit (train + validate), validate, test, or predict.

train_dataloader()

An iterable or collection of iterables specifying training samples.

test_dataloader()

An iterable or collection of iterables specifying test samples.

val_dataloader()

An iterable or collection of iterables specifying validation samples.

predict_dataloader()

An iterable or collection of iterables specifying prediction samples.

transfer_batch_to_device(batch, device, ...)

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

on_before_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

on_after_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

Inherited from HyperparametersMixin

__init__()

save_hyperparameters(*args[, ignore, frame, ...])

Save arguments to hparams attribute.

Private Methods:

Inherited from HyperparametersMixin

_set_hparams(hp)

_to_hparams_dict(hp)


setup(stage=None)[source]

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Args:

stage: either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this: state assigned here is not replicated across processes
        self.something = compute_something()

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
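
For SequenceDataModule specifically, setup() is the natural place to apply val_split and test_split. A possible sketch using torch.utils.data.random_split (illustrative only; the attribute names are assumptions, and the actual implementation may differ):

from torch.utils.data import random_split

def setup(self, stage=None):
    # illustrative sketch: partition the stored sequences into train/val/test
    n = len(self.sequences)
    n_test = int(n * self.test_split)
    n_val = int(n * self.val_split)
    n_train = n - n_val - n_test
    self.train_set, self.val_set, self.test_set = random_split(
        self.sequences, [n_train, n_val, n_test]
    )
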
test_dataloader()[source]

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see the multiple-dataloaders section of the Lightning documentation.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note:

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note:

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
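
If you do implement it, the hook just needs to return a dataloader (or collection of dataloaders) over the test data. A minimal override sketch (self.test_set and the loader settings are assumptions):

from torch.utils.data import DataLoader

def test_dataloader(self):
    # self.test_set is assumed to have been created in setup()
    return DataLoader(self.test_set, batch_size=32, num_workers=4)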

train_dataloader()[source]

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see the multiple-dataloaders section of the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note:

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
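
As noted above, the returned dataloader is cached; to have the Trainer rebuild it at every epoch (e.g. when the training set is resampled), set the flag when constructing the Trainer (a minimal sketch; the value 1 is arbitrary):

import pytorch_lightning as pl

# ask Lightning to call train_dataloader() again at each epoch boundary
trainer = pl.Trainer(reload_dataloaders_every_n_epochs=1)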

val_dataloader()[source]

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see the multiple-dataloaders section of the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data(). This dataloader is used by:

  • fit()

  • validate()

and the data itself should be downloaded and split beforehand in:

  • prepare_data()

  • setup()

Note:

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note:

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
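
"An iterable or collection of iterables" also covers the multiple-dataloader case: returning several loaders makes Lightning run validation over each one and pass a dataloader_idx to validation_step(). A sketch (the two dataset attributes are placeholders):

from torch.utils.data import DataLoader

def val_dataloader(self):
    # validate on two datasets; batches arrive tagged with dataloader_idx
    return [
        DataLoader(self.val_set_a, batch_size=32),
        DataLoader(self.val_set_b, batch_size=32),
    ]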