rbnet.multivariate_normal.ApproximateMixture

class rbnet.multivariate_normal.ApproximateMixture(means, log_weights=None, covariances=None, cat=False)

Bases: object

Approximate a mixture of N multivariate normal distributions by a single multivariate normal distribution via moment matching. This is equivalent to minimising the KL divergence (or the cross-entropy) from the mixture to the approximation and, if the means are data points, to minimising the negative log-likelihood.
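Moment matching sets the mean of the approximation to the weighted average of the component means, and its covariance to the weighted average of the component covariances plus the weighted spread of the component means around that overall mean. The following is a minimal NumPy sketch of this computation for the unbatched case. It illustrates the technique and is not rbnet's implementation; the function name `moment_match` is hypothetical, and it takes plain (non-log) weights, which it normalises.

```python
import numpy as np


def moment_match(means, weights, covariances=None):
    """Single Gaussian with the same mean and covariance as a Gaussian mixture.

    Hypothetical helper for illustration.
    means:       (N, D) component means
    weights:     (N,) non-negative, not necessarily normalised weights
    covariances: (N, D, D) component covariances, or None for point masses
    """
    means = np.asarray(means, dtype=float)
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()  # normalise the weights to sum to one
    mean = w @ means  # mu = sum_i w_i mu_i
    diff = means - mean  # (N, D) deviations of component means from mu
    # spread of the component means: sum_i w_i (mu_i - mu)(mu_i - mu)^T
    cov = np.einsum("n,ni,nj->ij", w, diff, diff)
    if covariances is not None:
        # spread within the components: sum_i w_i Sigma_i
        cov = cov + np.einsum("n,nij->ij", w, np.asarray(covariances, dtype=float))
    return mean, cov


# two equally weighted point masses at (0, 0) and (2, 0)
mean, cov = moment_match([[0.0, 0.0], [2.0, 0.0]], [1.0, 1.0])
print(mean)  # [1. 0.]
print(cov)   # [[1. 0.] [0. 0.]]
```

With zero covariances and uniform weights, the result coincides with the maximum-likelihood estimate from the data points (the sample mean and the biased, divide-by-N, sample covariance), which is the negative log-likelihood equivalence stated above.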

Parameters:
  • means (Union[Tensor, Iterable[Tensor]]) – (N,X,D) array with the means of the N components (or the locations of data points), where X denotes arbitrary batch dimensions and D the dimensionality of the distributions

  • log_weights (Union[Tensor, Iterable[Tensor], None]) – (N,X) array of log-weights of the components (optional; default is to assume uniform weights)

  • covariances (Union[Tensor, Iterable[Tensor], None]) – (N,X,D,D) array of covariance matrices of the components (optional; default is to assume zero covariance, which corresponds to treating the components as single data points)

  • cat (bool) – If True, the inputs are assumed to be iterables of tensors, which are concatenated first (default: False)
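The shape conventions of these parameters can be illustrated with a batched variant of the moment-matching computation. This is a hypothetical NumPy sketch (`approximate_mixture` is not part of rbnet) resting on two assumptions that the parameter descriptions do not state: the log-weights need not be normalised and are normalised over the component axis N, and with `cat=True` the iterables are concatenated along that same axis N.

```python
import numpy as np


def approximate_mixture(means, log_weights=None, covariances=None, cat=False):
    """Batched moment matching (hypothetical sketch, not rbnet's code).

    means: (N, *X, D); log_weights: (N, *X) or None; covariances: (N, *X, D, D) or None.
    Returns the matched mean (*X, D) and covariance (*X, D, D).
    """
    if cat:
        # ASSUMPTION: iterables of arrays are concatenated along the component axis N
        means = np.concatenate(list(means), axis=0)
        if log_weights is not None:
            log_weights = np.concatenate(list(log_weights), axis=0)
        if covariances is not None:
            covariances = np.concatenate(list(covariances), axis=0)
    means = np.asarray(means, dtype=float)
    if log_weights is None:
        log_weights = np.zeros(means.shape[:-1])  # uniform weights, shape (N, *X)
    lw = np.asarray(log_weights, dtype=float)
    # ASSUMPTION: normalise over N; subtracting the max keeps exp() numerically stable
    w = np.exp(lw - lw.max(axis=0, keepdims=True))
    w = w / w.sum(axis=0, keepdims=True)  # (N, *X), sums to one over N
    mean = np.sum(w[..., None] * means, axis=0)  # (*X, D)
    diff = means - mean  # (N, *X, D), broadcast over N
    outer = diff[..., :, None] * diff[..., None, :]  # (N, *X, D, D)
    cov = np.sum(w[..., None, None] * outer, axis=0)  # spread of the means
    if covariances is not None:
        covs = np.asarray(covariances, dtype=float)
        cov = cov + np.sum(w[..., None, None] * covs, axis=0)  # spread within components
    return mean, cov
```

For example, `means` of shape (4, 5, 3) (four components, one batch dimension of size 5, D = 3) yields a mean of shape (5, 3) and a covariance of shape (5, 3, 3), one matched Gaussian per batch element.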

Public Methods:

__init__(means[, log_weights, covariances, cat])
