rbnet.multivariate_normal.MultivariateNormal
- class rbnet.multivariate_normal.MultivariateNormal(loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, norm=None, log_norm=None, validate_args=None, dim=False)
Bases: object
Represents a random variable with a multivariate normal distribution. It wraps PyTorch’s torch.distributions.MultivariateNormal (available via the torch property) but extends it with additional functionality.

Initialise the variable with the given mean and covariance or precision matrix.
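The wrapping pattern described above can be illustrated with a minimal sketch. The class WrappedMVN below is hypothetical and is not rbnet's implementation; it only assumes that the wrapper builds a torch.distributions.MultivariateNormal from its arguments and returns it from a torch property, as the description states.

```python
import torch
from torch.distributions import MultivariateNormal as TorchMVN


class WrappedMVN:
    """Hypothetical wrapper around torch's MultivariateNormal (not rbnet's code)."""

    def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None):
        # Delegate argument checking and matrix factorisation to PyTorch,
        # which requires exactly one of the three matrix arguments.
        self._dist = TorchMVN(
            torch.as_tensor(loc),
            covariance_matrix=covariance_matrix,
            precision_matrix=precision_matrix,
            scale_tril=scale_tril,
        )

    @property
    def torch(self):
        # Expose the underlying torch distribution, e.g. for sampling or log_prob.
        return self._dist


# Usage: a 2D standard normal; its log density at the mean is -log(2*pi).
mvn = WrappedMVN([0.0, 0.0], covariance_matrix=torch.eye(2))
print(mvn.torch.log_prob(torch.zeros(2)))
```

Keeping the torch object behind a property lets the wrapper add its own methods while callers can still use the full torch.distributions interface.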
- Parameters:
  - loc (Tensor)
  - covariance_matrix (Optional[Tensor])
  - precision_matrix (Optional[Tensor])
  - scale_tril (Optional[Tensor])
  - norm (Optional[Tensor])
  - log_norm (Optional[Tensor])
  - validate_args (Optional[bool])
  - dim (Union[int, bool]) – explicitly specify the event dimensionality if loc is a (batch of) scalars; otherwise this is taken to correspond to the last axis of loc
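The matrix arguments share their names with those of torch.distributions.MultivariateNormal, which accepts exactly one of covariance_matrix, precision_matrix and scale_tril. Assuming rbnet forwards them with the same meaning (not confirmed here), the sketch below uses plain PyTorch to show that the three forms describe the same distribution: the precision matrix is the inverse of the covariance, and scale_tril is its lower Cholesky factor.

```python
import torch
from torch.distributions import MultivariateNormal

# An arbitrary symmetric positive-definite covariance matrix.
cov = torch.tensor([[2.0, 0.5],
                    [0.5, 1.0]], dtype=torch.float64)
loc = torch.zeros(2, dtype=torch.float64)

# Three equivalent parameterisations; only one matrix may be passed at a time.
by_cov = MultivariateNormal(loc, covariance_matrix=cov)
by_prec = MultivariateNormal(loc, precision_matrix=torch.linalg.inv(cov))
by_tril = MultivariateNormal(loc, scale_tril=torch.linalg.cholesky(cov))

# The log densities agree up to floating-point error.
x = torch.tensor([0.3, -1.2], dtype=torch.float64)
log_probs = [d.log_prob(x) for d in (by_cov, by_prec, by_tril)]
print(log_probs)
```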
Public Data Attributes:
Public Methods:
- __init__(loc[, covariance_matrix, ...]) – Initialise the variable with the given mean and covariance or precision matrix.
- __mul__(other)

Private Methods:
- _expand_scalar(diag, dim) – Expand an array of scalars into an array of diagonal matrices.
- _expand_diagonal(diag[, dim]) – Expand an array of diagonals into an array of diagonal matrices.
- _to_tensor(t) – Return t as a PyTorch tensor.
- _expand_mat(mat)
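The private helpers are documented only by their one-line summaries. The sketch below shows one plausible PyTorch implementation of that behaviour; the functions to_tensor, expand_scalar and expand_diagonal are hypothetical stand-ins, not rbnet's code, and the optional dim argument of _expand_diagonal is left out because its role is not documented.

```python
import torch


def to_tensor(t):
    # Hypothetical stand-in for _to_tensor: keep tensors as they are,
    # convert other inputs (floats, lists, NumPy arrays) with torch.as_tensor.
    return t if isinstance(t, torch.Tensor) else torch.as_tensor(t)


def expand_scalar(diag, dim):
    # Hypothetical stand-in for _expand_scalar: every scalar s in a tensor of
    # shape batch_shape becomes the (dim x dim) matrix s * I, so the result
    # has shape batch_shape + (dim, dim).
    diag = to_tensor(diag)
    eye = torch.eye(dim, dtype=diag.dtype)
    return diag[..., None, None] * eye


def expand_diagonal(diag):
    # Hypothetical stand-in for _expand_diagonal: a (..., n) batch of
    # diagonals becomes a (..., n, n) batch of diagonal matrices.
    return torch.diag_embed(to_tensor(diag))


# Usage: two isotropic 3x3 matrices, and one diagonal 2x2 matrix.
print(expand_scalar(torch.tensor([1.0, 2.0]), 3).shape)
print(expand_diagonal([1.0, 2.0]))
```

Adding two trailing singleton axes lets broadcasting against torch.eye handle any batch shape, including a single 0-d scalar.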
- property torch