rbnet.pcfg.DiscreteBinaryNonTerminalTransition
- class rbnet.pcfg.DiscreteBinaryNonTerminalTransition(weights, left_idx=0, right_idx=0, prob_rep=rbnet.util.LogProb, *args, **kwargs)
Bases: Transition, ConstrainedModuleMixin
A binary non-terminal transition for discrete non-terminal variables.
Initialise a non-terminal transition p(b, c | a) for random variables a, b, c, where a is the parent and b and c are the left and right child variables. The child variables b and c may be different variables from a (if the RBN has multiple non-terminal variables); which variables they are is determined by the left and right index. By default both indices are 0, which refers to the first non-terminal variable and is not necessarily the same variable as a.
- Parameters:
weights – NumPy array of shape (K, L, M) with weights proportional to p(b, c | a), where K, L, M are the cardinalities of the variables a, b, c, respectively.
left_idx – index of the left child variable (default 0)
right_idx – index of the right child variable (default 0)
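The weights only have to be proportional to the transition probabilities. As a minimal NumPy sketch (not rbnet's internal code), normalising them could look as follows, assuming that axis 0 indexes the parent a and axes 1 and 2 index the left and right child variables b and c, so that each slice for a fixed parent value becomes a distribution over child pairs. The helper normalise_binary_weights is hypothetical.

```python
import numpy as np


def normalise_binary_weights(weights):
    """Hypothetical helper (not part of rbnet): rescale non-negative weights of
    shape (K, L, M) so that, for every parent value a, the entries over the
    child pairs (b, c) sum to one."""
    weights = np.asarray(weights, dtype=float)
    if weights.ndim != 3:
        raise ValueError(f"expected an array of shape (K, L, M), got {weights.shape}")
    if np.any(weights < 0):
        raise ValueError("weights must be non-negative")
    # Total weight per parent value; keepdims=True makes the division broadcast.
    totals = weights.sum(axis=(1, 2), keepdims=True)
    if np.any(totals == 0):
        raise ValueError("every parent value needs at least one positive weight")
    return weights / totals


# Parent with K=2 values, left child with L=3 values, right child with M=4 values.
rng = np.random.default_rng(0)
probs = normalise_binary_weights(rng.uniform(0.1, 1.0, size=(2, 3, 4)))
print(probs.shape)             # (2, 3, 4)
print(probs.sum(axis=(1, 2)))  # each entry is 1 up to rounding
```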
Public Data Attributes:

Inherited from Transition
- dump_patches
- call_super_init
- forward(*input) – Define the computation performed at every call.

Inherited from Module
- dump_patches
- call_super_init
- T_destination
- training
Public Methods:
- __init__(weights[, left_idx, right_idx, ...]) – Initialise a non-terminal transition p(b, c | a) for random variables a, b, c.
- inside_marginals(location, inside_chart, ...) – Compute the marginals over inside probabilities.

Inherited from Transition
- __init__(*args, **kwargs)
- inside_marginals(location, inside_chart, ...) – Compute the marginals over inside probabilities.

Inherited from Module
- __init__(*args, **kwargs) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(*input) – Define the computation performed at every call.
- register_buffer(name, tensor[, persistent]) – Add a buffer to the module.
- register_parameter(name, param) – Add a parameter to the module.
- add_module(name, module) – Add a child module to the current module.
- register_module(name, module) – Alias for add_module().
- get_submodule(target) – Return the submodule given by target if it exists, otherwise throw an error.
- set_submodule(target, module) – Set the submodule given by target if it exists, otherwise throw an error.
- get_parameter(target) – Return the parameter given by target if it exists, otherwise throw an error.
- get_buffer(target) – Return the buffer given by target if it exists, otherwise throw an error.
- get_extra_state() – Return any extra state to include in the module's state_dict.
- set_extra_state(state) – Set extra state contained in the loaded state_dict.
- apply(fn) – Apply fn recursively to every submodule (as returned by .children()) as well as self.
- cuda([device]) – Move all model parameters and buffers to the GPU.
- ipu([device]) – Move all model parameters and buffers to the IPU.
- xpu([device]) – Move all model parameters and buffers to the XPU.
- mtia([device]) – Move all model parameters and buffers to the MTIA.
- cpu() – Move all model parameters and buffers to the CPU.
- type(dst_type) – Casts all parameters and buffers to dst_type.
- float() – Casts all floating point parameters and buffers to float datatype.
- double() – Casts all floating point parameters and buffers to double datatype.
- half() – Casts all floating point parameters and buffers to half datatype.
- bfloat16() – Casts all floating point parameters and buffers to bfloat16 datatype.
- to_empty(*, device[, recurse]) – Move the parameters and buffers to the specified device without copying storage.
- to(*args, **kwargs) – Move and/or cast the parameters and buffers.
- register_full_backward_pre_hook(hook[, prepend]) – Register a backward pre-hook on the module.
- register_backward_hook(hook) – Register a backward hook on the module.
- register_full_backward_hook(hook[, prepend]) – Register a backward hook on the module.
- register_forward_pre_hook(hook, *[, ...]) – Register a forward pre-hook on the module.
- register_forward_hook(hook, *[, prepend, ...]) – Register a forward hook on the module.
- __call__(*args, **kwargs) – Call self as a function.
- __getstate__()
- __setstate__(state)
- __getattr__(name)
- __setattr__(name, value) – Implement setattr(self, name, value).
- __delattr__(name) – Implement delattr(self, name).
- register_state_dict_post_hook(hook) – Register a post-hook for the state_dict() method.
- register_state_dict_pre_hook(hook) – Register a pre-hook for the state_dict() method.
- state_dict(*args[, destination, prefix, ...]) – Return a dictionary containing references to the whole state of the module.
- register_load_state_dict_pre_hook(hook) – Register a pre-hook to be run before module's load_state_dict() is called.
- register_load_state_dict_post_hook(hook) – Register a post-hook to be run after module's load_state_dict() is called.
- load_state_dict(state_dict[, strict, assign]) – Copy parameters and buffers from state_dict into this module and its descendants.
- parameters([recurse]) – Return an iterator over module parameters.
- named_parameters([prefix, recurse, ...]) – Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- buffers([recurse]) – Return an iterator over module buffers.
- named_buffers([prefix, recurse, ...]) – Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- children() – Return an iterator over immediate children modules.
- named_children() – Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
- modules() – Return an iterator over all modules in the network.
- named_modules([memo, prefix, remove_duplicate]) – Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
- train([mode]) – Set the module in training mode.
- eval() – Set the module in evaluation mode.
- requires_grad_([requires_grad]) – Change if autograd should record operations on parameters in this module.
- zero_grad([set_to_none]) – Reset gradients of all model parameters.
- share_memory() – See torch.Tensor.share_memory_().
- extra_repr() – Set the extra representation of the module.
- __repr__() – Return repr(self).
- __dir__() – Default dir() implementation.
- compile(*args, **kwargs) – Compile this Module's forward using torch.compile().

Inherited from ConstrainedModuleMixin
- enforce_constraints([recurse]) – Enforce constraints for module parameters and child modules.
- remap(param[, _top_level, prefix])

Private Data Attributes:
- _abc_impl

Inherited from Transition
- _abc_impl
- _version – This allows better BC support for load_state_dict().
- _parameters
- _buffers
- _non_persistent_buffers_set
- _backward_pre_hooks
- _backward_hooks
- _is_full_backward_hook
- _forward_hooks
- _forward_hooks_with_kwargs
- _forward_hooks_always_called
- _forward_pre_hooks
- _forward_pre_hooks_with_kwargs
- _state_dict_hooks
- _load_state_dict_pre_hooks
- _state_dict_pre_hooks
- _load_state_dict_post_hooks
- _modules
- _compiled_call_impl

Inherited from ABC
- _abc_impl

Inherited from Module
- _version – This allows better BC support for load_state_dict().
- _compiled_call_impl
- _parameters
- _buffers
- _non_persistent_buffers_set
- _backward_pre_hooks
- _backward_hooks
- _is_full_backward_hook
- _forward_hooks
- _forward_hooks_with_kwargs
- _forward_hooks_always_called
- _forward_pre_hooks
- _forward_pre_hooks_with_kwargs
- _state_dict_hooks
- _load_state_dict_pre_hooks
- _state_dict_pre_hooks
- _load_state_dict_post_hooks
- _modules

Private Methods:

Inherited from Module
- _apply(fn[, recurse])
- _get_backward_hooks() – Return the backward hooks for use in the call function.
- _get_backward_pre_hooks()
- _maybe_warn_non_full_backward_hook(inputs, ...)
- _slow_forward(*input, **kwargs)
- _wrapped_call_impl(*args, **kwargs)
- _call_impl(*args, **kwargs)
- _register_state_dict_hook(hook) – Register a post-hook for the state_dict() method.
- _save_to_state_dict(destination, prefix, ...) – Save module state to the destination dictionary.
- _register_load_state_dict_pre_hook(hook[, ...]) – See register_load_state_dict_pre_hook() for details.
- _load_from_state_dict(state_dict, prefix, ...) – Copy parameters and buffers from state_dict into only this module, but not its descendants.
- _named_members(get_members_fn[, prefix, ...]) – Help yield various names + members of modules.
- _get_name()
- _replicate_for_data_parallel()
- inside_marginals(location, inside_chart, terminal_chart, **kwargs)
Compute the marginals over inside probabilities for all possible splitting points. In particular, location specifies the variable's location in the parse chart (the span indices in the inside-probability equation), from which the possible splitting points follow (n − 1 splitting points for a transition of arity n). The marginals should always be returned in an array or iterable whose first dimension corresponds to all possible combinations of splitting points, even for transitions whose arity is not 2: for arity 1, where there are no splits, the first dimension should be of size 1, and for arity greater than 2, all possible combinations of the splitting points should be listed in flattened form in the first dimension. Additional dimensions may be used to represent the dependency of the marginal on the variable (e.g. for a discrete variable, the second dimension may list the marginal for each possible value the variable can take; for a continuous variable, the marginal may be represented by a set of parameters). The output of this function is typically handled by a custom implementation of Cell.inside_mixture().
- Parameters:
location – location of the variable for which to compute the inside marginals
inside_chart – a lookup chart with inside probabilities for other variables
terminal_chart – a lookup chart with values of the terminal variables
- Returns:
array-like or iterable with inside probabilities
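To make the shape convention concrete, here is a minimal NumPy sketch for the binary case. It is not rbnet's implementation: the chart format (plain dicts mapping a span (start, end) to a vector of inside probabilities in linear rather than log space), the function names split_point_combinations and binary_inside_marginals, and the reading of the weights as p(b, c | a) with a as the parent are all assumptions. Separate left and right charts stand in for the left and right child variables, which may differ when left_idx and right_idx select different non-terminal variables. The first dimension of the result runs over splitting points and the second over the values of the parent variable; split_point_combinations also shows that arity 1 yields exactly one (empty) combination and arity n yields tuples of n − 1 points.

```python
from itertools import combinations

import numpy as np


def split_point_combinations(start, end, arity):
    """All increasing tuples of (arity - 1) splitting points strictly inside the
    span (start, end). Arity 1 yields a single empty tuple, i.e. one entry."""
    if arity < 1:
        raise ValueError("arity must be at least 1")
    return list(combinations(range(start + 1, end), arity - 1))


def binary_inside_marginals(location, left_chart, right_chart, trans_probs):
    """Hypothetical sketch (not rbnet's implementation) for a binary transition.

    location:    (start, end) span of the parent variable a
    left_chart:  dict mapping (start, k) to inside probabilities over b, shape (L,)
    right_chart: dict mapping (k, end) to inside probabilities over c, shape (M,)
    trans_probs: array of shape (K, L, M) holding p(b, c | a)

    Returns an array of shape (number of splitting points, K).
    """
    start, end = location
    rows = []
    for (split,) in split_point_combinations(start, end, arity=2):
        left = left_chart[(start, split)]
        right = right_chart[(split, end)]
        # Sum over b, c of p(b, c | a) * inside_left[b] * inside_right[c].
        rows.append(np.einsum("abc,b,c->a", trans_probs, left, right))
    if not rows:
        # A span of length 1 cannot be split by a binary transition.
        return np.zeros((0, trans_probs.shape[0]))
    return np.stack(rows)


# Span (0, 3): a binary transition can split it at k = 1 or k = 2.
trans = np.full((2, 2, 2), 0.25)  # uniform p(b, c | a) with K = L = M = 2
chart = {(0, 1): np.array([1.0, 0.0]), (1, 3): np.array([0.5, 0.5]),
         (0, 2): np.array([0.2, 0.8]), (2, 3): np.array([0.0, 1.0])}
marginals = binary_inside_marginals((0, 3), chart, chart, trans)
print(marginals.shape)  # (2, 2); every entry is 0.25
print(split_point_combinations(0, 3, arity=1))  # [()]
```

In the standard inside algorithm, summing these rows over the first dimension would give the inside probability of each parent value for the whole span.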