rbnet.multivariate_normal.MultivariateNormal

class rbnet.multivariate_normal.MultivariateNormal(loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, norm=None, log_norm=None, validate_args=None, dim=False)

Bases: object

Represents a random variable with a multivariate normal distribution. It wraps PyTorch’s torch.distributions.MultivariateNormal (accessible via the torch property) and extends it with additional functionality.

Initialise the variable with the given mean (loc) and covariance or precision matrix.
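For reference, the parameterisations above describe the standard multivariate normal density over a d-dimensional event. Writing μ for loc and Σ for covariance_matrix, the precision matrix is Λ = Σ⁻¹. Following the convention of the wrapped PyTorch distribution, scale_tril is a lower-triangular L with Σ = L Lᵀ. The norm and log_norm parameters are not covered by this formula.

```latex
p(x) = (2\pi)^{-d/2}\,\lvert\Sigma\rvert^{-1/2}
       \exp\!\Big(-\tfrac{1}{2}\,(x-\mu)^\top \Sigma^{-1} (x-\mu)\Big),
\qquad \Lambda = \Sigma^{-1}, \qquad \Sigma = L L^\top
```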

Parameters:
  • loc (Tensor) – mean of the distribution

  • covariance_matrix (Optional[Tensor])

  • precision_matrix (Optional[Tensor])

  • scale_tril (Optional[Tensor])

  • norm (Optional[Tensor])

  • log_norm (Optional[Tensor])

  • validate_args (Optional[bool])

  • dim (Union[int, bool]) – explicitly specifies the event dimensionality when loc is a (batch of) scalar(s); otherwise, the event dimensionality is taken from the size of the last axis of loc
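The dim rule can be read as splitting the shape of loc into batch axes and an event size. The sketch below illustrates that reading on plain shape tuples. split_shape is a hypothetical helper, not part of rbnet, and it assumes that with an integer dim every axis of loc is a batch axis. It rejects dim=True because that case is not described above.

```python
from typing import Tuple, Union

Shape = Tuple[int, ...]


def split_shape(loc_shape: Shape, dim: Union[int, bool] = False) -> Tuple[Shape, int]:
    """Return (batch_shape, event_dim) for a loc of shape loc_shape.

    Hypothetical helper illustrating the documented `dim` rule; not rbnet API.
    """
    # bool is a subclass of int, so the bool case must be checked first.
    if isinstance(dim, bool):
        if dim:
            raise ValueError("dim must be False or a positive integer")
        if len(loc_shape) == 0:
            raise ValueError("a scalar loc needs an explicit integer dim")
        # Default: the last axis of loc is the event axis.
        return loc_shape[:-1], loc_shape[-1]
    if dim < 1:
        raise ValueError("dim must be a positive integer")
    # loc is a (batch of) scalar(s): every axis is a batch axis.
    return loc_shape, dim
```

For example, a loc of shape (5, 3) with the default dim=False gives a batch of 5 three-dimensional variables, while a loc of shape (5,) with dim=3 gives a batch of 5 scalars, each standing for a three-dimensional event.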

Public Data Attributes:

torch

Public Methods:

__init__(loc[, covariance_matrix, ...])

Initialise the variable with the given mean and covariance or precision matrix.

__mul__(other)

Private Methods:

_expand_scalar(diag, dim)

Expand an array of scalars into an array of diagonal matrices.

_expand_diagonal(diag[, dim])

Expand an array of diagonals into an array of diagonal matrices.

_to_tensor(t)

Return t as a pytorch tensor.

_expand_mat(mat)
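The bodies of the private expansion helpers are not shown here. The following is only a sketch of the transformations their summaries describe, written on plain nested lists with a single batch axis. expand_scalar and expand_diagonal are hypothetical stand-ins named after the private methods. The optional dim argument of _expand_diagonal is left out because its role is not documented. On tensors, torch.diag_embed performs the diagonal-vector case.

```python
from typing import List

Matrix = List[List[float]]


def expand_scalar(diag: List[float], dim: int) -> List[Matrix]:
    """Map each scalar s to the dim x dim matrix s * I (hypothetical sketch)."""
    return [
        [[s if i == j else 0.0 for j in range(dim)] for i in range(dim)]
        for s in diag
    ]


def expand_diagonal(diag: List[List[float]]) -> List[Matrix]:
    """Map each vector v to the square matrix with v on its diagonal (hypothetical sketch)."""
    return [
        [[v[i] if i == j else 0.0 for j in range(len(v))] for i in range(len(v))]
        for v in diag
    ]
```

Under this reading, expand_scalar([2.0, 3.0], 2) yields the two isotropic matrices 2·I and 3·I. expand_diagonal([[1.0, 2.0]]) yields a single 2 x 2 matrix with 1.0 and 2.0 on its diagonal.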


property torch