rbnet.multivariate_normal.MultivariateNormal
- class rbnet.multivariate_normal.MultivariateNormal(loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, norm=None, log_norm=None, validate_args=None, dim=False)
Bases: object
Represents a random variable with a multivariate normal distribution. It wraps PyTorch’s torch.distributions.MultivariateNormal (available via the torch property) but extends it with additional functionality.
Initialise variable with given mean and covariance or precision matrix.
- Parameters:
  - loc (Tensor)
  - covariance_matrix (Optional[Tensor])
  - precision_matrix (Optional[Tensor])
  - scale_tril (Optional[Tensor])
  - norm (Optional[Tensor])
  - log_norm (Optional[Tensor])
  - validate_args (Optional[bool])
  - dim (Union[int, bool]) – explicitly specify the event dimensionality if loc is a (batch of) scalars; otherwise this is taken to correspond to the last axis of loc
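The covariance_matrix, precision_matrix and scale_tril arguments carry the same names as those of the wrapped torch.distributions.MultivariateNormal, where exactly one of the three is supplied. The following is a minimal sketch of how the three parameterisations relate. It uses only the PyTorch class, not rbnet itself, and it does not cover norm, log_norm or dim, whose behaviour in rbnet may differ from anything shown here.

```python
import torch
from torch.distributions import MultivariateNormal as TorchMVN

loc = torch.zeros(2)
cov = torch.tensor([[2.0, 0.5],
                    [0.5, 1.0]])

# Three equivalent ways to describe the same distribution in PyTorch:
# the covariance, its inverse (precision), or its lower Cholesky factor.
by_cov = TorchMVN(loc, covariance_matrix=cov)
by_prec = TorchMVN(loc, precision_matrix=torch.linalg.inv(cov))
by_tril = TorchMVN(loc, scale_tril=torch.linalg.cholesky(cov))

# All three assign the same log-density to a given point.
x = torch.tensor([0.3, -0.7])
log_probs = torch.stack([d.log_prob(x) for d in (by_cov, by_prec, by_tril)])

# For a vector-valued loc, the event dimensionality is the size of its last axis.
event_shape = by_cov.event_shape
```

Here the three entries of log_probs agree up to floating-point error, and event_shape has a single entry equal to the length of loc.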
Public Data Attributes:
Public Methods:
- __init__(loc[, covariance_matrix, ...]) – Initialise variable with given mean and covariance or precision matrix.
- __mul__(other)
Private Methods:
- _expand_scalar(diag, dim) – Expand an array of scalars into an array of diagonal matrices.
- _expand_diagonal(diag[, dim]) – Expand an array of diagonals into an array of diagonal matrices.
- _to_tensor(t) – Return t as a pytorch tensor.
- _expand_mat(mat)
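The two expansion helpers are described only by their one-line summaries. As a rough illustration of what "expanding into diagonal matrices" can mean for batched tensors, here is a sketch with two hypothetical stand-in functions, expand_scalar and expand_diagonal. They are not the rbnet implementations, and the optional dim argument of _expand_diagonal is left out because its meaning is not documented here.

```python
import torch

def expand_scalar(diag, dim):
    """Hypothetical stand-in: scalars of shape (...) -> diagonal matrices of shape (..., dim, dim)."""
    diag = torch.as_tensor(diag)
    # Broadcasting each scalar against an identity matrix puts it on the diagonal.
    return diag[..., None, None] * torch.eye(dim, dtype=diag.dtype)

def expand_diagonal(diag):
    """Hypothetical stand-in: diagonals of shape (..., d) -> diagonal matrices of shape (..., d, d)."""
    return torch.diag_embed(torch.as_tensor(diag))

scalars = torch.tensor([1.0, 2.0])                    # batch of two scalars
mats = expand_scalar(scalars, dim=3)                  # shape (2, 3, 3)

diagonals = torch.tensor([[1.0, 2.0], [3.0, 4.0]])    # batch of two diagonals
diags = expand_diagonal(diagonals)                    # shape (2, 2, 2)
```

In both cases the leading batch axes are preserved and two matrix axes are appended at the end.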
- property torch