rbnet.pcfg.AbstractedPCFG
- class rbnet.pcfg.AbstractedPCFG(non_terminals, terminals, rules, start, prob_rep=<class 'rbnet.util.LogProb'>, *args, **kwargs)[source]
Bases: PCFG, LightningModule, ConstrainedModuleMixin
An AbstractedPCFG defines an RBN that has only one non-terminal and one terminal variable, both being discrete, with cardinalities corresponding to the number of non-terminal and terminal symbols of the PCFG, respectively.
- Parameters:
  non_terminals – list or array of non-terminal symbols
  terminals – list or array of terminal symbols
  rules – iterable of rule-weight tuples, with rules provided either as strings of the form ("X --> Y Z", w) or ("X --> Y", w) for non-terminal and terminal rules, respectively (for this, symbols have to be strings without whitespace), or of the form ((X, (Y, Z)), w) or ((X, (Y,)), w) for arbitrary symbols, where w is the rule weight
  start – the start symbol
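For example, a minimal construction sketch (the toy grammar and its weights are illustrative; only the argument format follows the parameter description above):

    from rbnet.pcfg import AbstractedPCFG

    # Toy grammar over whitespace-free string symbols, so the string rule
    # format can be used; ("S --> A A", 1.0) is equivalent to the tuple
    # form (("S", ("A", "A")), 1.0).
    pcfg = AbstractedPCFG(
        non_terminals=["S", "A"],   # non-terminal symbols
        terminals=["a"],            # terminal symbols
        rules=[
            ("S --> A A", 1.0),     # non-terminal rule with weight w=1.0
            ("A --> a", 1.0),       # terminal rule with weight w=1.0
        ],
        start="S",                  # the start symbol
    )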
Public Data Attributes:
Inherited from SequentialRBN:
- Return the location of the root variables.
- Return the chart with inside probabilities for all variables.
- Return the chart with terminal variables.
- Return the prior transition (typically an instance of Prior), which has to implement Prior.marginal_likelihood().
- dump_patches
- training
- call_super_init
- forward(*input): Define the computation performed at every call.
Inherited from RBN:
- Return the location of the root variables.
- Return the chart with inside probabilities for all variables.
- Return the chart with terminal variables.
- Return the prior transition (typically an instance of Prior), which has to implement Prior.marginal_likelihood().
Inherited from LightningModule:
- CHECKPOINT_HYPER_PARAMS_KEY
- CHECKPOINT_HYPER_PARAMS_NAME
- CHECKPOINT_HYPER_PARAMS_TYPE
- trainer
- fabric
- example_input_array: The example input array is a specification of what the module can consume in the forward() method.
- current_epoch: The current epoch in the Trainer, or 0 if not attached.
- global_step: Total training batches seen across all epochs.
- global_rank: The index of the current process across all nodes and devices.
- local_rank: The index of the current process within a single node.
- on_gpu: Returns True if this model is currently located on a GPU.
- automatic_optimization: If set to False, you are responsible for calling .backward(), .step(), and .zero_grad().
- strict_loading: Determines how Lightning loads this model using load_state_dict(..., strict=model.strict_loading).
- logger: Reference to the logger object in the Trainer.
- loggers: Reference to the list of loggers in the Trainer.
- device_mesh: Strategies like ModelParallelStrategy will create a device mesh that can be accessed in the configure_model() hook to parallelize the LightningModule.
Inherited from _DeviceDtypeModuleMixin:
- dtype
- device
Inherited from HyperparametersMixin:
- hparams: The collection of hyperparameters saved with save_hyperparameters().
- hparams_initial: The collection of hyperparameters saved with save_hyperparameters().
Inherited from Module:
- dump_patches
- call_super_init
- T_destination
- training
Public Methods:
- __init__(non_terminals, terminals, rules, start): An AbstractedPCFG defines an RBN that has only one non-terminal and one terminal variable, both being discrete, with cardinalities corresponding to the number of non-terminal and terminal symbols of the PCFG, respectively.
Inherited from PCFG:
- __init__(cells, prior, terminal_indices, ...)
- tokenise(sequence)
- init_inside(sequence): Initialise for parsing a new input.
- map_inside_chart([precision])
Inherited from SequentialRBN:
- __init__(cells, prior, *args, **kwargs)
- init_inside(sequence): Initialise for parsing a new input.
- inside_schedule(*args, **kwargs): Iterate through (batches of) non-terminal locations for computing inside probabilities.
- cells(): Return iterable over cells (corresponding to the non-terminal variables).
- update_inside_chart(var_idx, locations, values): For the specified variable, update the chart for inside probabilities with given values at given locations.
Inherited from RBN:
- __init__(*args, **kwargs)
- inside_schedule(*args, **kwargs): Iterate through (batches of) non-terminal locations for computing inside probabilities.
- cells(): Return iterable over cells (corresponding to the non-terminal variables).
- init_inside(*args, **kwargs): Initialise for parsing a new input.
- update_inside_chart(var_idx, locations, values): For the specified variable, update the chart for inside probabilities with given values at given locations.
- inside(*args, **kwargs): Compute the inside probabilities and return the marginal data likelihood (see the usage sketch below).
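A hedged usage sketch (the signatures above are abstracted as (*args, **kwargs), so the exact calling convention is an assumption; inside() is documented to compute the inside probabilities and return the marginal data likelihood):

    # Assumption: inside() accepts a sequence of terminal symbols and takes
    # care of initialisation itself; `pcfg` is the AbstractedPCFG from the
    # construction example above.
    sequence = ["a", "a"]
    likelihood = pcfg.inside(sequence)  # marginal data likelihood of the sequence

Internally, per the summaries above, init_inside() initialises the charts for a new input, inside_schedule() iterates over non-terminal locations, and update_inside_chart() writes the computed inside probabilities into the chart.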
Inherited from LightningModule:
- __init__(*args, **kwargs)
- optimizers([use_pl_optimizer]): Returns the optimizer(s) that are being used during training.
- lr_schedulers(): Returns the learning rate scheduler(s) that are being used during training.
- print(*args, **kwargs): Prints only from process 0.
- log(name, value[, prog_bar, logger, ...]): Log a key, value pair.
- log_dict(dictionary[, prog_bar, logger, ...]): Log a dictionary of values at once.
- all_gather(data[, group, sync_grads]): Gather tensors or collections of tensors from multiple processes.
- forward(*args, **kwargs): Same as torch.nn.Module.forward().
- training_step(*args, **kwargs): Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger (see the training sketch after this list).
- validation_step(*args, **kwargs): Operates on a single batch of data from the validation set.
- test_step(*args, **kwargs): Operates on a single batch of data from the test set.
- predict_step(*args, **kwargs): Step function called during predict().
- configure_callbacks(): Configure model-specific callbacks.
- configure_optimizers(): Choose what optimizers and learning-rate schedulers to use in your optimization.
- manual_backward(loss, *args, **kwargs): Call this directly from your training_step() when doing optimizations manually.
- backward(loss, *args, **kwargs): Called to perform backward on the loss returned in training_step().
- toggle_optimizer(optimizer): Makes sure only the gradients of the current optimizer's parameters are calculated in the training step, to prevent dangling gradients in multiple-optimizer setups.
- untoggle_optimizer(optimizer): Resets the state of required gradients that were toggled with toggle_optimizer().
- clip_gradients(optimizer[, ...]): Handles gradient clipping internally.
- configure_gradient_clipping(optimizer[, ...]): Perform gradient clipping for the optimizer parameters.
- lr_scheduler_step(scheduler, metric): Override this method to adjust the default way the Trainer calls each scheduler.
- optimizer_step(epoch, batch_idx, optimizer): Override this method to adjust the default way the Trainer calls the optimizer.
- optimizer_zero_grad(epoch, batch_idx, optimizer): Override this method to change the default behaviour of optimizer.zero_grad().
- freeze(): Freeze all params for inference.
- unfreeze(): Unfreeze all parameters for training.
- to_onnx(file_path[, input_sample]): Saves the model in ONNX format.
- to_torchscript([file_path, method, ...]): By default compiles the whole model to a ScriptModule.
- load_from_checkpoint(checkpoint_path[, ...]): Primary way of loading a model from a checkpoint.
- __getstate__()
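Since AbstractedPCFG is a LightningModule, it can in principle be trained with a pytorch_lightning.Trainer. The subclass below is a hypothetical illustration of the hook pattern; whether AbstractedPCFG already provides training_step() and configure_optimizers() implementations is not shown in these summaries:

    import torch

    from rbnet.pcfg import AbstractedPCFG

    class TrainablePCFG(AbstractedPCFG):  # hypothetical subclass, for illustration only
        def training_step(self, batch, batch_idx):
            # Assumption: inside() returns the batch (log-)likelihood as a
            # tensor, so its negation can serve as the training loss.
            loss = -self.inside(batch)
            self.log("train_loss", loss)  # log a key, value pair
            return loss

        def configure_optimizers(self):
            # Standard Lightning hook: choose the optimizer for this model.
            return torch.optim.Adam(self.parameters(), lr=1e-2)

A pytorch_lightning.Trainer would then drive the usual loop, e.g. Trainer(max_epochs=10).fit(model, train_dataloader).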
Inherited from _DeviceDtypeModuleMixin:
- __init__(): Initialize internal Module state, shared by both nn.Module and ScriptModule.
- to(*args, **kwargs): See torch.nn.Module.to().
- cuda([device]): Moves all model parameters and buffers to the GPU.
- cpu(): See torch.nn.Module.cpu().
- type(dst_type): See torch.nn.Module.type().
- float(): See torch.nn.Module.float().
- double(): See torch.nn.Module.double().
- half(): See torch.nn.Module.half().
Inherited from HyperparametersMixin:
- __init__()
- save_hyperparameters(*args[, ignore, frame, ...]): Save arguments to hparams attribute.
Inherited from ModelHooks:
- on_fit_start(): Called at the very beginning of fit.
- on_fit_end(): Called at the very end of fit.
- on_train_start(): Called at the beginning of training after sanity check.
- on_train_end(): Called at the end of training before logger experiment is closed.
- on_validation_start(): Called at the beginning of validation.
- on_validation_end(): Called at the end of validation.
- on_test_start(): Called at the beginning of testing.
- on_test_end(): Called at the end of testing.
- on_predict_start(): Called at the beginning of predicting.
- on_predict_end(): Called at the end of predicting.
- on_train_batch_start(batch, batch_idx): Called in the training loop before anything happens for that batch.
- on_train_batch_end(outputs, batch, batch_idx): Called in the training loop after the batch.
- on_validation_batch_start(batch, batch_idx): Called in the validation loop before anything happens for that batch.
- on_validation_batch_end(outputs, batch, ...): Called in the validation loop after the batch.
- on_test_batch_start(batch, batch_idx[, ...]): Called in the test loop before anything happens for that batch.
- on_test_batch_end(outputs, batch, batch_idx): Called in the test loop after the batch.
- on_predict_batch_start(batch, batch_idx[, ...]): Called in the predict loop before anything happens for that batch.
- on_predict_batch_end(outputs, batch, batch_idx): Called in the predict loop after the batch.
- on_validation_model_zero_grad(): Called by the training loop to release gradients before entering the validation loop.
- on_validation_model_eval(): Called when the validation loop starts.
- on_validation_model_train(): Called when the validation loop ends.
- on_test_model_eval(): Called when the test loop starts.
- on_test_model_train(): Called when the test loop ends.
- on_predict_model_eval(): Called when the predict loop starts.
- on_train_epoch_start(): Called in the training loop at the very beginning of the epoch.
- on_train_epoch_end(): Called in the training loop at the very end of the epoch.
- on_validation_epoch_start(): Called in the validation loop at the very beginning of the epoch.
- on_validation_epoch_end(): Called in the validation loop at the very end of the epoch.
- on_test_epoch_start(): Called in the test loop at the very beginning of the epoch.
- on_test_epoch_end(): Called in the test loop at the very end of the epoch.
- on_predict_epoch_start(): Called at the beginning of predicting.
- on_predict_epoch_end(): Called at the end of predicting.
- on_before_zero_grad(optimizer): Called after training_step() and before optimizer.zero_grad().
- on_before_backward(loss): Called before loss.backward().
- on_after_backward(): Called after loss.backward() and before optimizers are stepped.
- on_before_optimizer_step(optimizer): Called before optimizer.step().
- configure_sharded_model(): Deprecated.
- configure_model(): Hook to create modules in a strategy- and precision-aware context.
Inherited from DataHooks:
- __init__()
- prepare_data(): Use this to download and prepare data.
- setup(stage): Called at the beginning of fit (train + validate), validate, test, or predict.
- teardown(stage): Called at the end of fit (train + validate), validate, test, or predict.
- train_dataloader(): An iterable or collection of iterables specifying training samples.
- test_dataloader(): An iterable or collection of iterables specifying test samples.
- val_dataloader(): An iterable or collection of iterables specifying validation samples.
- predict_dataloader(): An iterable or collection of iterables specifying prediction samples.
- transfer_batch_to_device(batch, device, ...): Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
- on_before_batch_transfer(batch, dataloader_idx): Override to alter or apply batch augmentations to your batch before it is transferred to the device.
- on_after_batch_transfer(batch, dataloader_idx): Override to alter or apply batch augmentations to your batch after it is transferred to the device.
Inherited from CheckpointHooks:
- on_load_checkpoint(checkpoint): Called by Lightning to restore your model.
- on_save_checkpoint(checkpoint): Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
Inherited from Module:
- __init__(*args, **kwargs): Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(*input): Define the computation performed at every call.
- register_buffer(name, tensor[, persistent]): Add a buffer to the module.
- register_parameter(name, param): Add a parameter to the module.
- add_module(name, module): Add a child module to the current module.
- register_module(name, module): Alias for add_module().
- get_submodule(target): Return the submodule given by target if it exists, otherwise throw an error.
- set_submodule(target, module): Set the submodule given by target if it exists, otherwise throw an error.
- get_parameter(target): Return the parameter given by target if it exists, otherwise throw an error.
- get_buffer(target): Return the buffer given by target if it exists, otherwise throw an error.
- get_extra_state(): Return any extra state to include in the module's state_dict.
- set_extra_state(state): Set extra state contained in the loaded state_dict.
- apply(fn): Apply fn recursively to every submodule (as returned by .children()) as well as self.
- cuda([device]): Move all model parameters and buffers to the GPU.
- ipu([device]): Move all model parameters and buffers to the IPU.
- xpu([device]): Move all model parameters and buffers to the XPU.
- mtia([device]): Move all model parameters and buffers to the MTIA.
- cpu(): Move all model parameters and buffers to the CPU.
- type(dst_type): Casts all parameters and buffers to dst_type.
- float(): Casts all floating point parameters and buffers to float datatype.
- double(): Casts all floating point parameters and buffers to double datatype.
- half(): Casts all floating point parameters and buffers to half datatype.
- bfloat16(): Casts all floating point parameters and buffers to bfloat16 datatype.
- to_empty(*, device[, recurse]): Move the parameters and buffers to the specified device without copying storage.
- to(*args, **kwargs): Move and/or cast the parameters and buffers.
- register_full_backward_pre_hook(hook[, prepend]): Register a backward pre-hook on the module.
- register_backward_hook(hook): Register a backward hook on the module.
- register_full_backward_hook(hook[, prepend]): Register a backward hook on the module.
- register_forward_pre_hook(hook, *[, ...]): Register a forward pre-hook on the module.
- register_forward_hook(hook, *[, prepend, ...]): Register a forward hook on the module.
- __call__(*args, **kwargs): Call self as a function.
- __getstate__()
- __setstate__(state)
- __getattr__(name)
- __setattr__(name, value): Implement setattr(self, name, value).
- __delattr__(name): Implement delattr(self, name).
- register_state_dict_post_hook(hook): Register a post-hook for the state_dict() method.
- register_state_dict_pre_hook(hook): Register a pre-hook for the state_dict() method.
- state_dict(*args[, destination, prefix, ...]): Return a dictionary containing references to the whole state of the module.
- register_load_state_dict_pre_hook(hook): Register a pre-hook to be run before module's load_state_dict() is called.
- register_load_state_dict_post_hook(hook): Register a post-hook to be run after module's load_state_dict() is called.
- load_state_dict(state_dict[, strict, assign]): Copy parameters and buffers from state_dict into this module and its descendants.
- parameters([recurse]): Return an iterator over module parameters.
- named_parameters([prefix, recurse, ...]): Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- buffers([recurse]): Return an iterator over module buffers.
- named_buffers([prefix, recurse, ...]): Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- children(): Return an iterator over immediate children modules.
- named_children(): Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
- modules(): Return an iterator over all modules in the network.
- named_modules([memo, prefix, remove_duplicate]): Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
- train([mode]): Set the module in training mode.
- eval(): Set the module in evaluation mode.
- requires_grad_([requires_grad]): Change if autograd should record operations on parameters in this module.
- zero_grad([set_to_none]): Reset gradients of all model parameters.
- share_memory(): See torch.Tensor.share_memory_().
- extra_repr(): Set the extra representation of the module.
- __repr__(): Return repr(self).
- __dir__(): Default dir() implementation.
- compile(*args, **kwargs): Compile this Module's forward using torch.compile().
Inherited from ConstrainedModuleMixin:
- enforce_constraints([recurse]): Enforce constraints for module parameters and child modules (see the sketch after this list).
- remap(param[, _top_level, prefix])
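A hedged sketch of where the constraint hook might fit in manual optimisation code (what the constraints are for an AbstractedPCFG, e.g. normalised rule weights, is an assumption here):

    import torch

    # `pcfg` is the AbstractedPCFG from the construction example above;
    # using the negative (log-)likelihood as the loss is an assumption.
    optimizer = torch.optim.SGD(pcfg.parameters(), lr=0.1)
    loss = -pcfg.inside(["a", "a"])
    loss.backward()
    optimizer.step()
    pcfg.enforce_constraints()  # re-enforce parameter constraints after the update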
Private Data Attributes:
- _abc_impl
Inherited from PCFG:
- _abc_impl
Inherited from SequentialRBN:
- _abc_impl
- _version: This allows better BC support for load_state_dict().
- _parameters
- _buffers
- _non_persistent_buffers_set
- _backward_pre_hooks
- _backward_hooks
- _is_full_backward_hook
- _forward_hooks
- _forward_hooks_with_kwargs
- _forward_hooks_always_called
- _forward_pre_hooks
- _forward_pre_hooks_with_kwargs
- _state_dict_hooks
- _load_state_dict_pre_hooks
- _state_dict_pre_hooks
- _load_state_dict_post_hooks
- _modules
- _compiled_call_impl
Inherited from RBN:
- _abc_impl
Inherited from ABC:
- _abc_impl
Inherited from LightningModule:
- _jit_is_scripting
- _trainer
- _example_input_array
- _automatic_optimization
- _strict_loading
- _current_fx_name
- _param_requires_grad_state
- _metric_attributes
- _compiler_ctx
- _fabric
- _fabric_optimizers
- _device_mesh
- _dtype
- _version: This allows better BC support for load_state_dict().
- _parameters
- _buffers
- _non_persistent_buffers_set
- _backward_pre_hooks
- _backward_hooks
- _is_full_backward_hook
- _forward_hooks
- _forward_hooks_with_kwargs
- _forward_hooks_always_called
- _forward_pre_hooks
- _forward_pre_hooks_with_kwargs
- _state_dict_hooks
- _load_state_dict_pre_hooks
- _state_dict_pre_hooks
- _load_state_dict_post_hooks
- _modules
- _compiled_call_impl
Inherited from Module:
- _version: This allows better BC support for load_state_dict().
- _compiled_call_impl
- _parameters
- _buffers
- _non_persistent_buffers_set
- _backward_pre_hooks
- _backward_hooks
- _is_full_backward_hook
- _forward_hooks
- _forward_hooks_with_kwargs
- _forward_hooks_always_called
- _forward_pre_hooks
- _forward_pre_hooks_with_kwargs
- _state_dict_hooks
- _load_state_dict_pre_hooks
- _state_dict_pre_hooks
- _load_state_dict_post_hooks
- _modules
Private Methods:
Inherited from LightningModule:
- _call_batch_hook(hook_name, *args)
- _on_before_batch_transfer(batch[, ...])
- _apply_batch_transfer_handler(batch[, ...])
- _log_dict_through_fabric(dictionary[, logger])
- _LightningModule__check_not_nested(value, name)
- _LightningModule__check_allowed(v, name, value)
- _LightningModule__to_tensor(value, name)
- _verify_is_manual_optimization(fn_name)
Inherited from HyperparametersMixin:
- _set_hparams(hp)
- _to_hparams_dict(hp)
Inherited from Module:
- _apply(fn[, recurse])
- _get_backward_hooks(): Return the backward hooks for use in the call function.
- _get_backward_pre_hooks()
- _maybe_warn_non_full_backward_hook(inputs, ...)
- _slow_forward(*input, **kwargs)
- _wrapped_call_impl(*args, **kwargs)
- _call_impl(*args, **kwargs)
- _register_state_dict_hook(hook): Register a post-hook for the state_dict() method.
- _save_to_state_dict(destination, prefix, ...): Save module state to the destination dictionary.
- _register_load_state_dict_pre_hook(hook[, ...]): See register_load_state_dict_pre_hook() for details.
- _load_from_state_dict(state_dict, prefix, ...): Copy parameters and buffers from state_dict into only this module, but not its descendants.
- _named_members(get_members_fn[, prefix, ...]): Help yield various names + members of modules.
- _get_name()
- _replicate_for_data_parallel()