rbnet.multivariate_normal.ApproximateMixture
- class rbnet.multivariate_normal.ApproximateMixture(means, log_weights=None, covariances=None, cat=False)
Bases: object
Approximate a mixture of N multivariate normal distributions with a single one by matching moments. This is equivalent to minimising the KL divergence (or, equivalently, the cross-entropy) from the mixture to the approximation. If the means are data points, it is also equivalent to minimising the negative log-likelihood.
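The moment matching described above can be written out explicitly. This is a sketch. Here ℓ_n denotes the log-weight of component n, and normalising the exponentiated log-weights is an assumption that this page does not state:

```latex
w_n = \frac{\exp(\ell_n)}{\sum_{m=1}^{N} \exp(\ell_m)}, \qquad
\boldsymbol{\mu} = \sum_{n=1}^{N} w_n \, \boldsymbol{\mu}_n, \qquad
\boldsymbol{\Sigma} = \sum_{n=1}^{N} w_n \left[ \boldsymbol{\Sigma}_n + (\boldsymbol{\mu}_n - \boldsymbol{\mu})(\boldsymbol{\mu}_n - \boldsymbol{\mu})^{\top} \right]
```

Take zero component covariances and uniform weights w_n = 1/N. The result is then the sample mean and the biased (1/N) sample covariance of the points μ_n, which is the maximum-likelihood Gaussian fit.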
- Parameters:
  - means (Union[Tensor, Iterable[Tensor]]) – (N, X, D) array with the means / locations of the data points; X are arbitrary batch dimensions.
  - log_weights (Union[Tensor, Iterable[Tensor], None]) – (N, X) array of log-weights of the components (optional; the default assumes uniform weights).
  - covariances (Union[Tensor, Iterable[Tensor], None]) – (N, X, D, D) array of covariance matrices of the components (optional; the default assumes zero covariance, which corresponds to treating the components as single data points).
  - cat – If true, assume the inputs are iterables of tensors, which need to be concatenated first.
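A hypothetical NumPy function `approximate_mixture` can illustrate the parameter conventions above: batch dimensions X, default uniform weights, default zero covariances, and `cat`. It is a sketch, not rbnet's implementation. It uses NumPy arrays instead of PyTorch tensors. It also assumes that log-weights are normalised over the component axis and that `cat=True` concatenates along that axis (dimension 0).

```python
import numpy as np


def approximate_mixture(means, log_weights=None, covariances=None, cat=False):
    """Hypothetical sketch: moment-match a mixture of N Gaussians with one Gaussian.

    Shapes: means (N, *X, D), log_weights (N, *X), covariances (N, *X, D, D).
    Returns (mean, covariance) with shapes (*X, D) and (*X, D, D).
    """
    if cat:
        # Assumption: iterables of arrays are concatenated along the component axis.
        means = np.concatenate(list(means), axis=0)
        if log_weights is not None:
            log_weights = np.concatenate(list(log_weights), axis=0)
        if covariances is not None:
            covariances = np.concatenate(list(covariances), axis=0)
    means = np.asarray(means, dtype=float)
    n = means.shape[0]
    if log_weights is None:
        # Default: uniform weights 1/N for every batch element.
        log_weights = np.full(means.shape[:-1], -np.log(n))
    log_weights = np.asarray(log_weights, dtype=float)
    # Assumption: normalise the weights over the component axis (stable softmax).
    w = np.exp(log_weights - log_weights.max(axis=0, keepdims=True))
    w = w / w.sum(axis=0, keepdims=True)
    # First moment: weighted average of the component means, shape (*X, D).
    mean = np.einsum("n...,n...d->...d", w, means)
    # Spread of the component means around the overall mean, shape (*X, D, D).
    diff = means - mean
    cov = np.einsum("n...,n...d,n...e->...de", w, diff, diff)
    if covariances is not None:
        # Add the weighted within-component covariances; None means zero covariance.
        covariances = np.asarray(covariances, dtype=float)
        cov = cov + np.einsum("n...,n...de->...de", w, covariances)
    return mean, cov
```

For example, suppose it is called with four 2-D points and no weights or covariances. It then returns their sample mean and their biased (1/N) sample covariance. This matches the interpretation of the default zero covariances as single data points.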
Public Methods:
- __init__(means[, log_weights, covariances, cat]) – Approximate a mixture of N multivariate normal distributions with a single one by matching moments.