rbnet.util.SequenceDataModule
- class rbnet.util.SequenceDataModule(sequences, val_split=0.2, test_split=0.1)[source]
Bases: LightningDataModule
- Attributes:
- prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data. Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
- allow_zero_length_dataloader_with_multiple_devices: If True, a dataloader with zero length within a local rank is allowed. Default value is False.
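A minimal usage sketch (only the constructor signature above is taken from this page; the sequence element type and the downstream model are assumptions):

    import torch
    from rbnet.util import SequenceDataModule

    # hypothetical toy data: a list of variable-length sequences
    sequences = [torch.randn(10, 4), torch.randn(7, 4), torch.randn(12, 4)]

    # 20% of the sequences go to validation, 10% to test (the defaults)
    dm = SequenceDataModule(sequences, val_split=0.2, test_split=0.1)

    # a pytorch_lightning.Trainer drives the setup() and dataloader hooks:
    # trainer = pytorch_lightning.Trainer(max_epochs=10)
    # trainer.fit(model, datamodule=dm)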
Public Data Attributes:
Inherited from LightningDataModule:
- name
- CHECKPOINT_HYPER_PARAMS_KEY
- CHECKPOINT_HYPER_PARAMS_NAME
- CHECKPOINT_HYPER_PARAMS_TYPE
Inherited from HyperparametersMixin:
- hparams: The collection of hyperparameters saved with save_hyperparameters().
- hparams_initial: The collection of hyperparameters saved with save_hyperparameters().
Public Methods:
- __init__(sequences[, val_split, test_split])
- setup([stage]): Called at the beginning of fit (train + validate), validate, test, or predict.
- train_dataloader(): An iterable or collection of iterables specifying training samples.
- val_dataloader(): An iterable or collection of iterables specifying validation samples.
- test_dataloader(): An iterable or collection of iterables specifying test samples.
Inherited from LightningDataModule:
- __init__()
- from_datasets([train_dataset, val_dataset, ...]): Create an instance from torch.utils.data.Dataset.
- state_dict(): Called when saving a checkpoint; implement to generate and save datamodule state.
- load_state_dict(state_dict): Called when loading a checkpoint; implement to reload datamodule state given datamodule state_dict.
- on_exception(exception): Called when the trainer execution is interrupted by an exception.
- load_from_checkpoint(checkpoint_path[, ...]): Primary way of loading a datamodule from a checkpoint.
Inherited from DataHooks:
- __init__()
- prepare_data(): Use this to download and prepare data.
- setup(stage): Called at the beginning of fit (train + validate), validate, test, or predict.
- teardown(stage): Called at the end of fit (train + validate), validate, test, or predict.
- train_dataloader(): An iterable or collection of iterables specifying training samples.
- test_dataloader(): An iterable or collection of iterables specifying test samples.
- val_dataloader(): An iterable or collection of iterables specifying validation samples.
- predict_dataloader(): An iterable or collection of iterables specifying prediction samples.
- transfer_batch_to_device(batch, device, ...): Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
- on_before_batch_transfer(batch, dataloader_idx): Override to alter or apply batch augmentations to your batch before it is transferred to the device.
- on_after_batch_transfer(batch, dataloader_idx): Override to alter or apply batch augmentations to your batch after it is transferred to the device.
Inherited from HyperparametersMixin:
- __init__()
- save_hyperparameters(*args[, ignore, frame, ...]): Save arguments to the hparams attribute.
Private Methods:
Inherited from HyperparametersMixin:
- _set_hparams(hp)
- _to_hparams_dict(hp)
- setup(stage=None)[source]
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Args:
stage: either 'fit', 'validate', 'test', or 'predict'
Example:

    class LitModel(...):
        def __init__(self):
            self.l1 = None

        def prepare_data(self):
            download_data()
            tokenize()
            # don't do this
            self.something = ...

        def setup(self, stage):
            data = load_data(...)
            self.l1 = nn.Linear(28, data.num_classes)
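For a DataModule like this one, setup() is typically where the stored data is split into train/validation/test subsets. A minimal sketch of that pattern, assuming torch.utils.data.random_split and the attribute names shown (this is not the actual rbnet implementation):

    from pytorch_lightning import LightningDataModule
    from torch.utils.data import random_split

    class MySequenceDataModule(LightningDataModule):
        def __init__(self, sequences, val_split=0.2, test_split=0.1):
            super().__init__()
            self.sequences = sequences
            self.val_split = val_split
            self.test_split = test_split

        def setup(self, stage=None):
            # split sizes derived from the configured fractions
            n = len(self.sequences)
            n_val = int(n * self.val_split)
            n_test = int(n * self.test_split)
            n_train = n - n_val - n_test
            self.train_set, self.val_set, self.test_set = random_split(
                self.sequences, [n_train, n_val, n_test]
            )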
- test_dataloader()[source]
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see this section.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning: do not assign state in prepare_data().
See also: test(), prepare_data().
- Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Note:
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
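A hedged sketch of the download-then-split pattern described above; download_dataset() and load_dataset() are placeholder helpers, not rbnet or Lightning APIs:

    from pytorch_lightning import LightningDataModule
    from torch.utils.data import DataLoader

    class MyDataModule(LightningDataModule):
        def prepare_data(self):
            # runs once (per node): download only, no state assignment
            download_dataset("https://example.com/data.tar.gz")  # placeholder

        def setup(self, stage=None):
            # runs on every process: safe to assign state here
            data = load_dataset()  # placeholder
            self.test_set = data.test

        def test_dataloader(self):
            return DataLoader(self.test_set, batch_size=32)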
- train_dataloader()[source]
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning: do not assign state in prepare_data().
See also: fit(), prepare_data().
- Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
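For example, to have the dataloaders rebuilt every epoch, set the flag on the Trainer (a sketch; the other arguments are illustrative):

    import pytorch_lightning as pl

    trainer = pl.Trainer(
        max_epochs=10,
        reload_dataloaders_every_n_epochs=1,  # re-call train_dataloader() each epoch
    )
    # trainer.fit(model, datamodule=dm)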
- val_dataloader()[source]
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
See also: fit(), validate(), prepare_data().
- Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Note:
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
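Returning a collection of iterables is also supported; a minimal sketch assuming two hypothetical validation sets assigned in setup():

    from pytorch_lightning import LightningDataModule
    from torch.utils.data import DataLoader

    class TwoValSetsDataModule(LightningDataModule):
        # val_set_a and val_set_b are hypothetical datasets created in setup()
        def val_dataloader(self):
            # Lightning runs validation over both loaders and reports
            # metrics per dataloader index
            return [
                DataLoader(self.val_set_a, batch_size=32),
                DataLoader(self.val_set_b, batch_size=32),
            ]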