Source code for rbnet.multivariate_normal

#  Copyright (c) 2021 Robert Lieck

import math
from functools import cached_property
import operator
from typing import Union
from collections.abc import Iterable

import torch
from torch.distributions import MultivariateNormal as TorchMultivariateNormal
import numpy as np


class MultivariateNormal:
    """
    Represents a random variable with a multivariate normal distribution. It wraps PyTorch's
    :class:`torch.distributions.MultivariateNormal` (available via the :attr:`torch` property)
    but extends it with additional functionality.
    """

    __FULL__ = "__FULL__"
    __DIAG__ = "__DIAG__"
    __SCAL__ = "__SCAL__"

    @classmethod
    def _expand_scalar(cls, diag, dim):
        """
        Expand an array of scalars into an array of diagonal matrices.

        :param diag: Array of arbitrary shape (...) with scalars that specify the value on the
         diagonal of the matrices.
        :param dim: Dimensionality of the resulting matrices.
        :return: Array of shape (..., dim, dim) with diagonal matrices.

        .. testsetup::

            from rbnet.multivariate_normal import MultivariateNormal
            import torch

        .. testcode::

            diag = torch.arange(3, dtype=torch.float)
            print(diag)
            print(MultivariateNormal._expand_scalar(diag=diag, dim=3))

        .. testoutput::

            tensor([0., 1., 2.])
            tensor([[[0., 0., 0.],
                     [0., 0., 0.],
                     [0., 0., 0.]],

                    [[1., 0., 0.],
                     [0., 1., 0.],
                     [0., 0., 1.]],

                    [[2., 0., 0.],
                     [0., 2., 0.],
                     [0., 0., 2.]]])
        """
        diag_ = torch.zeros(diag.shape + (dim, dim))
        diag_[..., torch.arange(dim), torch.arange(dim)] = diag[..., None]
        return diag_

    @classmethod
    def _expand_diagonal(cls, diag, dim=None):
        """
        Expand an array of diagonals into an array of diagonal matrices.

        :param diag: Array of shape (..., dim) with entries that specify the values on the
         diagonal of the matrices.
        :param dim: Dimensionality of the resulting matrices.
        :return: Array of shape (..., dim, dim) with diagonal matrices.
        """
        if dim is None:
            dim = diag.shape[-1]
        else:
            if diag.shape[-1] != dim:
                raise ValueError(f"Last dimension of tensor has size {diag.shape[-1]}, "
                                 f"cannot expand to diagonal of size {dim}")
        mat_ = torch.zeros(diag.shape + (dim,))
        mat_[..., torch.arange(dim), torch.arange(dim)] = diag
        return mat_

    @classmethod
    def _to_tensor(cls, t):
        """
        Return `t` as a PyTorch tensor. If `t` already is a tensor, it is returned as is;
        otherwise ``torch.tensor(t)`` is returned.
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        return t

    def __init__(self,
                 loc: torch.Tensor,
                 covariance_matrix: Union[torch.Tensor, None] = None,
                 precision_matrix: Union[torch.Tensor, None] = None,
                 scale_tril: Union[torch.Tensor, None] = None,
                 norm: Union[torch.Tensor, None] = None,
                 log_norm: Union[torch.Tensor, None] = None,
                 validate_args: Union[bool, None] = None,
                 dim: Union[int, bool] = False):
        """
        Initialise the variable with given mean and covariance or precision matrix.

        :param loc: mean of the distribution
        :param covariance_matrix: covariance matrix (exactly one of `covariance_matrix`,
         `precision_matrix`, and `scale_tril` must be provided)
        :param precision_matrix: precision matrix, i.e. the inverse of the covariance matrix
        :param scale_tril: lower-triangular factor of the covariance matrix
        :param norm: normalisation factor of the distribution (at most one of `norm` and
         `log_norm` may be provided; default is a factor of one)
        :param log_norm: log-normalisation factor of the distribution (default is zero)
        :param validate_args: passed on to :class:`torch.distributions.MultivariateNormal`
        :param dim: explicitly specify the event dimensionality if `loc` is a (batch of)
         scalars; otherwise this is taken to correspond to the last axis of `loc`
        """
        # make sure at most one of 'norm' and 'log_norm' is provided
        if (norm is not None) + (log_norm is not None) > 1:
            raise ValueError("At most one of 'norm' and 'log_norm' may be specified.")
        if norm is None and log_norm is None:
            self._log_norm = self._to_tensor(0.)
        elif log_norm is not None:
            self._log_norm = self._to_tensor(log_norm)
        else:  # norm is not None
            self._log_norm = self._to_tensor(norm).log()
        # get the location
        self._loc = self._to_tensor(loc)
        if dim:
            # event dimensionality explicitly defined
            self._event_dim = dim
            # expand scalar locations accordingly
            self._expanded_loc = self._loc[..., None].expand(self._loc.shape + (dim,))
        else:
            # event dimensionality given by last axis of loc
            if self._loc.dim() < 1:
                # loc is scalar
                self._event_dim = self._loc.shape
            else:
                # loc is not scalar
                self._event_dim = self._loc.shape[-1]
            self._expanded_loc = self._loc
        if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1:
            raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril must be specified.")
        if scale_tril is not None:
            self._scale_tril = mat = self._to_tensor(scale_tril)
        elif covariance_matrix is not None:
            self._covariance_matrix = mat = self._to_tensor(covariance_matrix)
        else:
            self._precision_matrix = mat = self._to_tensor(precision_matrix)
        # classify the matrix parameterisation as scalar, diagonal, or full, based on how many
        # dimensions it has relative to the (expanded) location
        self._mat_type = [self.__SCAL__, self.__DIAG__, self.__FULL__][mat.dim() - self._expanded_loc.dim() + 1]
        self._batch_shape = torch.broadcast_shapes(mat.shape[:len(self._expanded_loc.shape) - 1],
                                                   self._expanded_loc.shape[:-1])
        self._validate_args = validate_args
        self._torch = None

    def _expand_mat(self, mat: torch.Tensor) -> Union[torch.Tensor, None]:
        if self._mat_type == self.__SCAL__:
            mat = self._expand_scalar(mat, self._event_dim)
        elif self._mat_type == self.__DIAG__:
            mat = self._expand_diagonal(mat, self._event_dim)
        return mat

    @property
    def torch(self):
        if self._torch is None:
            if hasattr(self, "_scale_tril"):
                self._torch = TorchMultivariateNormal(loc=self._expanded_loc,
                                                      scale_tril=self._expand_mat(self._scale_tril),
                                                      validate_args=self._validate_args)
            elif hasattr(self, "_covariance_matrix"):
                self._torch = TorchMultivariateNormal(loc=self._expanded_loc,
                                                      covariance_matrix=self._expand_mat(self._covariance_matrix),
                                                      validate_args=self._validate_args)
            elif hasattr(self, "_precision_matrix"):
                self._torch = TorchMultivariateNormal(loc=self._expanded_loc,
                                                      precision_matrix=self._expand_mat(self._precision_matrix),
                                                      validate_args=self._validate_args)
        return self._torch

    def __mul__(self, other):
        if not isinstance(other, MultivariateNormal):
            return NotImplemented
        else:
            prod = PairwiseProduct(mean1=self._loc, cov1=self.torch.covariance_matrix,
                                   mean2=other._loc, cov2=other.torch.covariance_matrix)
            return MultivariateNormal(loc=prod.mean,
                                      scale_tril=prod.torch.scale_tril,
                                      log_norm=prod.log_norm + self._log_norm + other._log_norm)
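

# A minimal usage sketch (added for illustration, not part of the original source): multiplying
# two wrapped Gaussians returns a new MultivariateNormal whose log-normalisation accumulates the
# pairwise product's log-norm. The variable names below are purely illustrative.
#
#     a = MultivariateNormal(loc=torch.zeros(2), covariance_matrix=torch.eye(2))
#     b = MultivariateNormal(loc=torch.ones(2), covariance_matrix=2 * torch.eye(2))
#     c = a * b                  # unnormalised Gaussian product
#     print(c.torch.mean)        # mean of the product distribution
#     print(c._log_norm)         # log of the overall scaling factor (private attribute)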


class PairwiseProduct:

    def __init__(self, mean1, mean2, cov1=None, cov2=None, prec1=None, prec2=None):
        assert not (cov1 is None and prec1 is None), "either cov1 or prec1 needs to be provided"
        assert not (cov2 is None and prec2 is None), "either cov2 or prec2 needs to be provided"
        self.mean1 = mean1
        self.mean2 = mean2
        self._cov1 = cov1
        self._cov2 = cov2
        self._prec1 = prec1
        self._prec2 = prec2
        self._sum_cov = None
        self._sum_cov_inv = None
        self._cov = None
        self._prec = None
        self._torch = None

    @classmethod
    def _cost_comp(cls, cost1, cost2, comp):
        """
        Compare cost1 and cost2 using comp. Costs are numpy arrays [i, mm, mv, ma, va] with i
        indicating the number of matrix inversions, mm that of matrix-matrix multiplications,
        mv that of matrix-vector multiplications, ma that of matrix-matrix additions, and va
        that of vector-vector additions. Comparison is performed element-wise in this order,
        that is::

            if comp(i1, i2):
                return True
            elif not i1 == i2:
                return False
            # continue with mm1 and mm2 etc.
        """
        for c1, c2 in zip(cost1, cost2):
            if comp(c1, c2):
                return True
            elif not c1 == c2:
                return False
        return False

    @classmethod
    def _cost_lt(cls, cost1, cost2):
        return cls._cost_comp(cost1=cost1, cost2=cost2, comp=operator.lt)

    @property
    def cov1(self):
        if self._cov1 is None:
            self._cov1 = torch.inverse(self.prec1)
        return self._cov1

    def _cov1_cost(self):
        return np.array([int(self._cov1 is None), 0, 0, 0, 0])

    @property
    def cov2(self):
        if self._cov2 is None:
            self._cov2 = torch.inverse(self.prec2)
        return self._cov2

    def _cov2_cost(self):
        return np.array([int(self._cov2 is None), 0, 0, 0, 0])

    @property
    def prec1(self):
        if self._prec1 is None:
            self._prec1 = torch.inverse(self.cov1)
        return self._prec1

    def _prec1_cost(self):
        return np.array([int(self._prec1 is None), 0, 0, 0, 0])

    @property
    def prec2(self):
        if self._prec2 is None:
            self._prec2 = torch.inverse(self.cov2)
        return self._prec2

    def _prec2_cost(self):
        return np.array([int(self._prec2 is None), 0, 0, 0, 0])

    @property
    def sum_cov(self):
        if self._sum_cov is None:
            self._sum_cov = self.cov1 + self.cov2
        return self._sum_cov

    def _sum_cov_cost(self):
        if self._sum_cov is None:
            return np.array([0, 0, 0, 1, 0]) + self._cov1_cost() + self._cov2_cost()
        else:
            return np.array([0, 0, 0, 0, 0])

    @property
    def sum_cov_inv(self):
        if self._sum_cov_inv is None:
            self._sum_cov_inv = torch.inverse(self.sum_cov)
        return self._sum_cov_inv

    def _sum_cov_inv_cost(self):
        if self._sum_cov_inv is None:
            return np.array([1, 0, 0, 0, 0]) + self._sum_cov_cost()
        else:
            return np.array([0, 0, 0, 0, 0])

    @property
    def prec(self):
        if self._prec is None:
            if self._cost_lt(self._prec_v1_cost(), self._prec_v2_cost()):
                self._prec = self.prec1 + self.prec2
            else:
                self._prec = torch.inverse(self.cov)
        return self._prec

    def _prec_v1_cost(self):
        """Cost of computing the precision by adding prec1 and prec2."""
        return np.array([0, 0, 0, 1, 0]) + self._prec1_cost() + self._prec2_cost()

    def _prec_v2_cost(self):
        """
        Cost of computing the precision by inverting cov (if cov is not present, assume
        infinite costs for tie-breaking).
        """
        if self._cov is None:
            return np.ones(5) * np.inf
        else:
            return np.array([1, 0, 0, 0, 0])

    def _prec_cost(self):
        if self._prec is None:
            return np.minimum(self._prec_v1_cost(), self._prec_v2_cost())
        else:
            return np.array([0, 0, 0, 0, 0])

    def _cov_v1(self):
        """Compute the covariance from prec1 and prec2."""
        return torch.inverse(self.prec)

    def _cov_v1_cost(self):
        return np.array([1, 0, 0, 0, 0]) + self._prec_cost()

    def _cov_v2(self):
        """Compute the covariance from cov1 and cov2."""
        return torch.matmul(torch.matmul(self.cov1, self.sum_cov_inv), self.cov2)

    def _cov_v2_cost(self):
        return np.array([0, 2, 0, 0, 0]) + self._cov1_cost() + self._cov2_cost() + self._sum_cov_inv_cost()

    @property
    def cov(self):
        """Compute the covariance matrix using the least expensive method."""
        if self._cov is None:
            if self._cost_lt(self._cov_v1_cost(), self._cov_v2_cost()):
                self._cov = self._cov_v1()
            else:
                self._cov = self._cov_v2()
        return self._cov

    def _cov_cost(self):
        if self._cov is None:
            return np.minimum(self._cov_v1_cost(), self._cov_v2_cost())
        else:
            return np.array([0, 0, 0, 0, 0])

    def _mean_v1(self):
        """Compute the mean from cov, prec1, and prec2."""
        return torch.einsum('...ab,...b->...a',
                            self.cov,
                            torch.einsum('...ab,...b->...a', self.prec1, self.mean1) +
                            torch.einsum('...ab,...b->...a', self.prec2, self.mean2))

    def _mean_v1_cost(self):
        return np.array([0, 0, 3, 0, 1]) + self._cov_cost() + self._prec1_cost() + self._prec2_cost()

    def _mean_v2(self):
        """Compute the mean from cov1, cov2, and sum_cov_inv."""
        return torch.einsum('...ab,...b->...a', self.cov2,
                            torch.einsum('...ab,...b->...a', self.sum_cov_inv, self.mean1)) + \
               torch.einsum('...ab,...b->...a', self.cov1,
                            torch.einsum('...ab,...b->...a', self.sum_cov_inv, self.mean2))

    def _mean_v2_cost(self):
        return np.array([0, 0, 4, 0, 1]) + self._sum_cov_inv_cost() + self._cov1_cost() + self._cov2_cost()

    @cached_property
    def mean(self):
        """Compute the mean using the least expensive method."""
        if self._cost_lt(self._mean_v1_cost(), self._mean_v2_cost()):
            return self._mean_v1()
        else:
            return self._mean_v2()

    @cached_property
    def torch(self):
        """Return the product distribution as a torch MultivariateNormal."""
        if self._torch is None:
            self._torch = TorchMultivariateNormal(loc=self.mean, covariance_matrix=self.cov)
        return self._torch

    @cached_property
    def log_norm(self):
        """Compute the log-normalisation using the least expensive method."""
        # if both have equal costs, use the precision matrix, which is less expensive within
        # TorchMultivariateNormal
        if self._cost_lt(self._sum_cov_cost(), self._sum_cov_inv_cost()):
            return TorchMultivariateNormal(loc=self.mean1, covariance_matrix=self.sum_cov).log_prob(self.mean2)
        else:
            return TorchMultivariateNormal(loc=self.mean1, precision_matrix=self.sum_cov_inv).log_prob(self.mean2)


class Product:

    def __init__(self, means: torch.Tensor, covariances: torch.Tensor = None, precisions: torch.Tensor = None,
                 determinants: torch.Tensor = None, scaled_means: torch.Tensor = None, method: str = None):
        """
        Represents the product of N multivariate normal distributions over the same random
        variable. The inputs may additionally have an arbitrary number of batch dimensions
        (indicated as '...' below). The means and either the covariances or precisions have to
        be provided; the remaining parameters are computed from them (providing them avoids
        recomputation). The `method` argument determines how the normalisation factor is
        computed. The methods return a triplet with the log scaling factor (...), the mean
        (...xD), and the covariance matrix (...xDxD) of the product, which are also available
        as the log_norm, mean, and covariance attributes, respectively.

        :param means: Nx...xD array of means
        :param covariances: Nx...xDxD array of covariance matrices
        :param precisions: Nx...xDxD array of precision matrices
        :param determinants: Nx... array with determinants of the covariance matrices
        :param scaled_means: Nx...xD array with products of precision matrices and means
        :param method: method to use for computing the scaling factor (None/'default', 'iter',
         'pair', 'commute'); 'pair' assumes N=2
        """
        self._means = means
        self.N = self._means.shape[0]
        self.D = self._means.shape[-1]
        # make sure we have both the covariance and precision matrices
        assert covariances is not None or precisions is not None
        if covariances is not None:
            self._covariances = covariances
            self._precisions = torch.inverse(covariances)
        else:
            self._covariances = torch.inverse(precisions)
            self._precisions = precisions
        # compute determinants
        if determinants is None:
            self._determinants = torch.det(self._covariances)
        else:
            self._determinants = determinants
        # compute scaled means
        if scaled_means is None:
            self._scaled_means = torch.einsum('n...ab,n...b->n...a', self._precisions, self._means)
        else:
            self._scaled_means = scaled_means
        # compute parameters of product distribution
        self.precision = self._precisions.sum(dim=0)
        self.covariance = torch.inverse(self.precision)
        self.mean = torch.einsum('...ab,...b->...a', self.covariance, self._scaled_means.sum(dim=0))
        self.det = torch.det(self.covariance)
        # compute normalisation factor
        if method is None or method == 'default':
            self.log_norm, _, _ = self.product()
        elif method == 'iter':
            self.log_norm, _, _ = self.iter_product(means=self._means, covariances=self._covariances)
        elif method == 'pair':
            self.log_norm = PairwiseProduct(mean1=self._means[0], cov1=self._covariances[0],
                                            mean2=self._means[1], cov2=self._covariances[1]).log_norm
        elif method == 'commute':
            self.log_norm, _, _ = self.commuting_product()
        else:
            raise ValueError(f"Unknown method '{method}'")

    def product(self):
        quad_factor = torch.einsum('n...a,n...a->n...', self._scaled_means, self._means).sum(dim=0)
        mixed_factor = torch.einsum('n...a,m...a->nm...',
                                    torch.einsum('...ab,n...b->n...a', self.covariance, self._scaled_means),
                                    self._scaled_means).sum(dim=(0, 1))
        exp_factor = -(quad_factor - mixed_factor) / 2
        div = self.det.log() - self._determinants.log().sum(dim=0)
        pi = math.log(2 * math.pi) * (-self.D * (self.N - 1))
        det_factor = (pi + div) / 2
        return det_factor + exp_factor, self.mean, self.covariance
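
    # Sketch of the closed-form expression implemented by `product` above (added for clarity,
    # not part of the original source): with precisions L_n = S_n^{-1}, product precision
    # L = sum_n L_n, product covariance S = L^{-1}, and scaled means s_n = L_n m_n, the
    # log-normalisation of the product of N D-dimensional Gaussians is
    #
    #     log z = ( -D (N - 1) log(2 pi) + log|S| - sum_n log|S_n| ) / 2
    #             - ( sum_n m_n^T L_n m_n - (sum_n s_n)^T S (sum_m s_m) ) / 2
    #
    # which corresponds to det_factor + exp_factor with quad_factor = sum_n m_n^T L_n m_n and
    # mixed_factor = sum_{n,m} s_n^T S s_m.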

    @classmethod
    def iter_product(cls, means, covariances):
        ret_log_fac = 0
        ret_mean = None
        ret_cov = None
        for m, cov in zip(means, covariances):
            if ret_mean is None and ret_cov is None:
                # initialise with the first component
                ret_mean = m
                ret_cov = cov
                continue
            # fold in the next component via a pairwise product, accumulating the log-normalisation
            pp = PairwiseProduct(mean1=ret_mean, cov1=ret_cov, mean2=m, cov2=cov)
            ret_log_fac += pp.log_norm
            ret_mean = pp.mean
            ret_cov = pp.cov
        return ret_log_fac, ret_mean, ret_cov

    def commuting_product(self):
        """
        Compute the normalisation factor using a pairwise decomposition for commuting
        covariance matrices.
        """
        exp_factor = 0
        for i in range(self.N):
            for j in range(i + 1, self.N):
                cov_ij = self._covariances[j].matmul(self.precision).matmul(self._covariances[i])
                exp_factor += TorchMultivariateNormal(self._means[i], cov_ij).log_prob(self._means[j])
        det_factor = (math.log(2 * math.pi) * (self.D * (self.N - 1) * (self.N - 2) / 2)
                      + (self._determinants.log() * (self.N - 2)).sum(dim=0)
                      - self.det.log() * (self.N * (self.N - 1) / 2 - 1)
                      ) / 2
        return det_factor + exp_factor, self.mean, self.covariance
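

# Illustrative consistency check (added for clarity, not part of the original source): the
# different methods for computing the normalisation factor should agree. Diagonal covariance
# matrices are used so that 'commute' applies; the values below are made up for the sketch.
#
#     means = torch.randn(3, 2)                                    # N=3 Gaussians in D=2
#     covs = torch.eye(2).expand(3, 2, 2) * (torch.rand(3, 1, 1) + 0.5)
#     default = Product(means, covariances=covs).log_norm
#     iterative = Product(means, covariances=covs, method='iter').log_norm
#     commuting = Product(means, covariances=covs, method='commute').log_norm
#     assert torch.isclose(default, iterative, atol=1e-5)
#     assert torch.isclose(default, commuting, atol=1e-5)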


class ApproximateMixture:

    def __init__(self,
                 means: Union[torch.Tensor, Iterable[torch.Tensor]],
                 log_weights: Union[torch.Tensor, Iterable[torch.Tensor]] = None,
                 covariances: Union[torch.Tensor, Iterable[torch.Tensor]] = None,
                 cat=False):
        """
        Approximate a mixture of N multivariate normal distributions with a single one by
        matching moments (equivalent to minimising the KL-divergence or cross-entropy from the
        mixture to the approximation, and the negative log-likelihood if the means are data
        points).

        :param means: (N,X,D) array with means / locations of data points; X are arbitrary
         batch dimensions
        :param log_weights: (N,X) array of log-weights of the components (optional; default is
         to assume uniform weights)
        :param covariances: (N,X,D,D) array of covariance matrices of the components (optional;
         default is to assume zero covariance, which corresponds to the components being
         treated as single data points)
        :param cat: If True, assume the inputs are iterables of tensors, which need to be
         concatenated first
        """
        # first concatenate inputs if requested
        if cat:
            means = torch.cat(tuple(means), dim=0)
            if log_weights is not None:
                log_weights = torch.cat(tuple(log_weights), dim=0)
            if covariances is not None:
                covariances = torch.cat(tuple(covariances), dim=0)
        # remember means
        self._means = means
        # get dimensions
        self.N = self._means.shape[0]
        self.D = self._means.shape[-1]
        # init uniform weights if not provided; get normalisation
        if log_weights is None:
            log_weights = (torch.ones(self._means.shape[:-1]) / self.N).log()
        self.log_norm = log_weights.logsumexp(dim=0)
        self.norm_log_weights = log_weights - self.log_norm
        self.norm_weights = self.norm_log_weights.exp()
        # compute mean
        self.mean = (self.norm_weights[:, ..., None] * self._means).sum(dim=0)
        # compute covariance component of the means
        diff = self._means - self.mean[None, ..., :]
        mean_cov = (self.norm_weights[..., None, None] * diff[..., :, None] * diff[..., None, :]).sum(dim=0)
        # compute covariance
        if covariances is None:
            # only the means component is non-zero (case for single data points)
            self.covariance = mean_cov
        else:
            # compute the component of the covariances and add it to the component of the means
            self._cov_cov = (self.norm_weights[:, ..., None, None] * covariances).sum(dim=0)
            self.covariance = mean_cov + self._cov_cov
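

# Moment-matching sketch (added for illustration, not part of the original source): with
# uniform weights and no component covariances, the approximation reduces to the sample mean
# and the (uncorrected) sample covariance of the means. The names below are made up.
#
#     points = torch.randn(100, 3)         # N=100 data points in D=3
#     approx = ApproximateMixture(means=points)
#     centred = points - points.mean(dim=0)
#     assert torch.allclose(approx.mean, points.mean(dim=0), atol=1e-6)
#     assert torch.allclose(approx.covariance, centred.T @ centred / 100, atol=1e-5)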