rbnet.util

Functions

as_detached_tensor(t)

Create a detached copy of a tensor.

ensure_is_floating_point(t[, msg])

log_normalize(t, *args, **kwargs)

Normalise tensor t in log representation by computing

$t - \log \sum \exp(t)$

using PyTorch logsumexp.
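A minimal, dependency-free sketch of the same computation on a 1-D list follows. The name `log_normalize_1d` is hypothetical and is not the rbnet implementation, which works on PyTorch tensors via logsumexp. It shows why subtracting the log-sum-exp normalises in log space: after the shift, the exponentiated values sum to one. The largest value is subtracted first so that very large log values do not overflow.

```python
import math


def log_normalize_1d(log_values):
    """Hypothetical stdlib stand-in for t - logsumexp(t) on a 1-D list.

    Returns log values whose exponentials sum to one.
    Assumes at least one entry is finite (an all -inf input is not handled).
    """
    # Shift by the maximum before exponentiating to avoid overflow.
    m = max(log_values)
    lse = m + math.log(sum(math.exp(v - m) for v in log_values))
    # Subtracting log(sum(exp(t))) divides every probability by the total.
    return [v - lse for v in log_values]


log_p = log_normalize_1d([0.0, math.log(3.0)])
print([math.exp(v) for v in log_p])  # approximately [0.25, 0.75]
```

Because of the max shift, inputs such as `[1000.0, 1001.0, 1002.0]` are handled without overflow, even though `math.exp(1000.0)` alone would raise `OverflowError`.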

normalize_non_zero(a[, axis, ...])

For the given N-D array (NumPy or PyTorch), normalise each 1D array obtained by indexing the 'axis' dimension, provided the sum along the other dimensions (for that entry) is non-zero.
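For the 2-D case with axis 0, indexing the first dimension yields rows, and the sum "along the other dimensions" is the row sum. The sketch below is a hypothetical stdlib illustration of that reading (`normalize_rows_non_zero` is not part of rbnet). Rows whose sum is non-zero are scaled to sum to one, and all-zero rows are left unchanged instead of producing a division by zero.

```python
def normalize_rows_non_zero(rows):
    """Hypothetical 2-D illustration of normalize_non_zero with axis=0.

    Each row with a non-zero sum is divided by that sum; rows summing
    to zero are copied unchanged.
    """
    result = []
    for row in rows:
        total = sum(row)
        if total != 0:
            result.append([x / total for x in row])
        else:
            # A zero sum cannot be normalised, so keep the row as-is.
            result.append(list(row))
    return result


print(normalize_rows_non_zero([[1.0, 3.0], [0.0, 0.0]]))
# [[0.25, 0.75], [0.0, 0.0]]
```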

plot_grad(func[, x_min, y_min, x_max, ...])

plot_vec(func[, x_min, y_min, x_max, y_max, ...])

Classes

ConstrainedModuleList([modules])

A plain ModuleList combined with ConstrainedModuleMixin, so that it cooperates with constrained submodules and does not break recursive calls.

ConstrainedModuleMixin()

A mixin class that lets modules with constraints work cooperatively.
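The reason a container needs the mixin can be sketched with plain Python classes. Everything here is hypothetical: `Node` stands in for a module with children, and `ConstraintMixin.enforce_constraints` stands in for whatever recursive constraint method rbnet actually uses. The recursion only descends into children that also carry the mixin, so a plain container hides constrained modules nested inside it. Subclasses override the method and call `super()` so that the next class in the method resolution order still runs.

```python
class Node:
    """Minimal stand-in for a module with child modules (assumption)."""

    def __init__(self, *children):
        self._children = list(children)

    def children(self):
        return iter(self._children)


class ConstraintMixin:
    """Hypothetical mixin whose method recurses into constrained children."""

    def enforce_constraints(self, recurse=True):
        if recurse:
            for child in self.children():
                # Only children that also use the mixin are visited, so a
                # container without the mixin stops the recursion here.
                if isinstance(child, ConstraintMixin):
                    child.enforce_constraints(recurse=True)


class Leaf(ConstraintMixin, Node):
    def __init__(self):
        super().__init__()
        self.enforced = False

    def enforce_constraints(self, recurse=True):
        self.enforced = True
        # Cooperative: pass control on to the next class in the MRO.
        super().enforce_constraints(recurse=recurse)


class PlainList(Node):
    """Container without the mixin."""


class ConstrainedList(ConstraintMixin, Node):
    """Container with the mixin, analogous in spirit to ConstrainedModuleList."""


a, b = Leaf(), Leaf()
root = ConstrainedList(PlainList(a), ConstrainedList(b))
root.enforce_constraints()
print(a.enforced, b.enforced)  # False True
```

Leaf `a` is never reached because `PlainList` lacks the mixin, while `b` is reached through the constrained container.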

LogProb([p, log_p, dim, raise_zero_norms])

Prob(p[, dim, raise_zero_norms])

A class for probability distributions that enforces positivity and normalisation constraints and projects the gradient in backward passes.
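The two ideas in this description can be sketched on a single probability vector. This is an illustration under stated assumptions, not rbnet's Prob implementation. Assumption 1: "projects the gradient" is read as removing the gradient component that would change the total mass, that is, projecting onto directions whose entries sum to zero. Assumption 2: positivity and normalisation are enforced by clamping negatives to zero and renormalising. The helper names are hypothetical.

```python
def project_to_sum_zero(grad):
    """Hypothetical gradient projection: subtract the mean so entries sum to zero.

    A step along the result leaves the sum of the probabilities unchanged.
    """
    mean = sum(grad) / len(grad)
    return [g - mean for g in grad]


def enforce_prob_constraints(p):
    """Hypothetical constraint step: clamp negatives to zero, then renormalise."""
    clipped = [max(x, 0.0) for x in p]
    total = sum(clipped)
    if total == 0:
        # An all-zero vector has no valid normalisation.
        raise ValueError("cannot normalise a vector with zero sum")
    return [x / total for x in clipped]


g = project_to_sum_zero([1.0, 2.0, 6.0])
print(g)  # [-2.0, -1.0, 3.0]

p = enforce_prob_constraints([0.2, -0.1, 0.6])
print(sum(p))  # approximately 1.0
```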

SequenceDataModule(sequences[, val_split, ...])

TupleTMap(arrs, *args, **kwargs)

A tuple of TMap objects.