#  Copyright (c) 2021 Robert Lieck
import math
from functools import cached_property
import operator
from typing import Union
from collections.abc import Iterable
import torch
from torch.distributions import MultivariateNormal as TorchMultivariateNormal
import numpy as np
class MultivariateNormal:
    """
    Represents a random variable with a multivariate normal distribution. It wraps PyTorch's
    :class:`torch.distributions.MultivariateNormal` (available via the :attr:`torch` property) but extends it with
    additional functionality.
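
    The covariance (or precision or scale) matrix may be given as a full matrix, as a (batch of) diagonals, or as a
    (batch of) scalars, which are expanded to full diagonal matrices. For example:

    .. testsetup::

        from rbnet.multivariate_normal import MultivariateNormal
        import torch

    .. testcode::

        # full covariance matrix
        mvn = MultivariateNormal(loc=torch.zeros(2), covariance_matrix=torch.eye(2))
        print(round(mvn.torch.log_prob(torch.zeros(2)).item(), 4))
        # diagonal shorthand: expanded to a diagonal covariance matrix
        mvn = MultivariateNormal(loc=torch.zeros(2), covariance_matrix=torch.ones(2))
        print(mvn.torch.covariance_matrix)

    .. testoutput::

        -1.8379
        tensor([[1., 0.],
                [0., 1.]])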
    """
    __FULL__ = "__FULL__"
    __DIAG__ = "__DIAG__"
    __SCAL__ = "__SCAL__"
    @classmethod
    def _expand_scalar(cls, diag, dim):
        """
        Expand an array of scalars into an array of diagonal matrices.
        :param diag: Array of arbitrary shape (...) with scalars that specify the value on the diagonal of the matrices.
        :param dim: Dimensionality of the resulting matrices.
        :return: Array of shape (..., dim, dim) with diagonal matrices.
        .. testsetup::
            from rbnet.multivariate_normal import MultivariateNormal
            import torch
        .. testcode::
            diag = torch.arange(3, dtype=torch.float)
            print(diag)
            print(MultivariateNormal._expand_scalar(diag=diag, dim=3))
        .. testoutput::
            tensor([0., 1., 2.])
            tensor([[[0., 0., 0.],
                     [0., 0., 0.],
                     [0., 0., 0.]],
                    [[1., 0., 0.],
                     [0., 1., 0.],
                     [0., 0., 1.]],
                    [[2., 0., 0.],
                     [0., 2., 0.],
                     [0., 0., 2.]]])
        """
        diag_ = torch.zeros(diag.shape + (dim, dim))
        diag_[..., torch.arange(dim), torch.arange(dim)] = diag[..., None]
        return diag_
    @classmethod
    def _expand_diagonal(cls, diag, dim=None):
        """
        Expand an array of diagonals into an array of diagonal matrices.
        :param diag: Array of shape (..., dim) with entries that specify the values on the diagonal of the matrices.
        :param dim: Dimensionality of the resulting matrices.
        :return: Array of shape (..., dim, dim) with diagonal matrices.
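
        For example (mirroring the doctest of :meth:`_expand_scalar`):

        .. testsetup::

            from rbnet.multivariate_normal import MultivariateNormal
            import torch

        .. testcode::

            diag = torch.arange(2, dtype=torch.float)
            print(MultivariateNormal._expand_diagonal(diag=diag))

        .. testoutput::

            tensor([[0., 0.],
                    [0., 1.]])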
        """
        if dim is None:
            dim = diag.shape[-1]
        else:
            if diag.shape[-1] != dim:
                raise ValueError(f"Last dimension of tensor has size {diag.shape[-1]}, "
                                 f"cannot expand to diagonal of size {dim}")
        mat_ = torch.zeros(diag.shape + (dim,))
        mat_[..., torch.arange(dim), torch.arange(dim)] = diag
        return mat_
    @classmethod
    def _to_tensor(cls, t):
        """
        Return `t` as a PyTorch tensor. If `t` is already a tensor, it is returned as is; otherwise,
        ``torch.tensor(t)`` is returned.
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        return t
    def __init__(self, loc: torch.Tensor,
                 covariance_matrix: Union[torch.Tensor, None] = None,
                 precision_matrix: Union[torch.Tensor, None] = None,
                 scale_tril: Union[torch.Tensor, None] = None,
                 norm: Union[torch.Tensor, None] = None,
                 log_norm: Union[torch.Tensor, None] = None,
                 validate_args: Union[bool, None] = None,
                 dim: Union[int, bool] = False):
        """
        Initialise the variable with a given mean and exactly one of covariance matrix, precision matrix, or scale
        matrix.

        :param loc: mean of the distribution
        :param covariance_matrix: positive-definite covariance matrix; may also be given as a (batch of) diagonals
         or scalars, which are expanded to full diagonal matrices
        :param precision_matrix: positive-definite precision matrix (inverse of the covariance); the same shorthands
         as for `covariance_matrix` apply
        :param scale_tril: lower-triangular factor of the covariance with positive-valued diagonal; the same
         shorthands as for `covariance_matrix` apply
        :param norm: normalisation constant (mutually exclusive with `log_norm`; defaults to 1)
        :param log_norm: log of the normalisation constant (mutually exclusive with `norm`; defaults to 0)
        :param validate_args: passed on to :class:`torch.distributions.MultivariateNormal`
        :param dim: explicitly specify the event dimensionality if `loc` is a (batch of) scalars; otherwise the
         event dimensionality is taken to correspond to the last axis of `loc`
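
        For example, a scalar covariance is expanded to the full event dimensionality:

        .. testsetup::

            from rbnet.multivariate_normal import MultivariateNormal
            import torch

        .. testcode::

            mvn = MultivariateNormal(loc=torch.zeros(2), covariance_matrix=torch.tensor(2.))
            print(mvn.torch.covariance_matrix)

        .. testoutput::

            tensor([[2., 0.],
                    [0., 2.]])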
        """
        # make sure at most one of 'norm' and 'log_norm' is provided
        if (norm is not None) + (log_norm is not None) > 1:
            raise ValueError("At most one of 'norm' and 'log_norm' may be specified.")
        if norm is None and log_norm is None:
            self._log_norm = self._to_tensor(0.)
        elif log_norm is not None:
            self._log_norm = self._to_tensor(log_norm)
        else:  # norm is not None
            self._log_norm = self._to_tensor(norm).log()
        # get the location
        self._loc = self._to_tensor(loc)
        if dim:
            # event dimensionality explicitly defined
            self._event_dim = dim
            # expand scalar locations accordingly
            self._expanded_loc = self._loc[..., None].expand(self._loc.shape + (dim,))
        else:
            # event dimensionality given by last axis of loc
            if self._loc.dim() < 1:
                # loc is a scalar: treat it as a one-dimensional event
                self._event_dim = 1
                self._expanded_loc = self._loc[..., None]
            else:
                # loc has at least one axis: its last axis is the event dimension
                self._event_dim = self._loc.shape[-1]
                self._expanded_loc = self._loc
        if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1:
            raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril must be specified.")
        if scale_tril is not None:
            self._scale_tril = mat = self._to_tensor(scale_tril)
        elif covariance_matrix is not None:
            self._covariance_matrix = mat = self._to_tensor(covariance_matrix)
        else:
            self._precision_matrix = mat = self._to_tensor(precision_matrix)
        self._mat_type = [self.__SCAL__, self.__DIAG__, self.__FULL__][mat.dim() - self._expanded_loc.dim() + 1]
        self._batch_shape = torch.broadcast_shapes(mat.shape[:len(self._expanded_loc.shape) - 1],
                                                   self._expanded_loc.shape[:-1])
        self._validate_args = validate_args
        self._torch = None
    def _expand_mat(self, mat: torch.Tensor) -> torch.Tensor:
        if self._mat_type == self.__SCAL__:
            mat = self._expand_scalar(mat, self._event_dim)
        elif self._mat_type == self.__DIAG__:
            mat = self._expand_diagonal(mat, self._event_dim)
        return mat
    @property
    def torch(self):
        if self._torch is None:
            if hasattr(self, "_scale_tril"):
                self._torch = TorchMultivariateNormal(loc=self._expanded_loc,
                                                      scale_tril=self._expand_mat(self._scale_tril),
                                                      validate_args=self._validate_args)
            elif hasattr(self, "_covariance_matrix"):
                self._torch = TorchMultivariateNormal(loc=self._expanded_loc,
                                                      covariance_matrix=self._expand_mat(self._covariance_matrix),
                                                      validate_args=self._validate_args)
            elif hasattr(self, "_precision_matrix"):
                self._torch = TorchMultivariateNormal(loc=self._expanded_loc,
                                                      precision_matrix=self._expand_mat(self._precision_matrix),
                                                      validate_args=self._validate_args)
        return self._torch
    def __mul__(self, other):
        if not isinstance(other, MultivariateNormal):
            return NotImplemented
        else:
            prod = PairwiseProduct(mean1=self._loc, cov1=self.torch.covariance_matrix,
                                   mean2=other._loc, cov2=other.torch.covariance_matrix)
            return MultivariateNormal(loc=prod.mean,
                                      scale_tril=prod.torch.scale_tril,
                                      log_norm=prod.log_norm + self._log_norm + other._log_norm) 
class PairwiseProduct:
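    r"""
    Represents the (generally unnormalised) product of two multivariate normal densities over the same variable.
    For :math:`\mathcal{N}(\mu_1, \Sigma_1)` and :math:`\mathcal{N}(\mu_2, \Sigma_2)`, the product is a scaled
    normal density with covariance and mean

    .. math::

        \Sigma = \Sigma_1 (\Sigma_1 + \Sigma_2)^{-1} \Sigma_2, \qquad
        \mu = \Sigma_2 (\Sigma_1 + \Sigma_2)^{-1} \mu_1 + \Sigma_1 (\Sigma_1 + \Sigma_2)^{-1} \mu_2,

    and log scaling factor :math:`\log \mathcal{N}(\mu_2 \mid \mu_1, \Sigma_1 + \Sigma_2)`, available via the
    :attr:`cov`, :attr:`mean`, and :attr:`log_norm` properties, respectively. Intermediate results are cached, and
    each property chooses the least expensive way to compute its value from what is already available.

    For example, for two univariate standard normals centred at 0 and 2:

    .. testsetup::

        from rbnet.multivariate_normal import PairwiseProduct
        import torch

    .. testcode::

        pp = PairwiseProduct(mean1=torch.tensor([0.]), cov1=torch.eye(1),
                             mean2=torch.tensor([2.]), cov2=torch.eye(1))
        print(pp.mean, pp.cov)
        print(round(pp.log_norm.item(), 4))

    .. testoutput::

        tensor([1.]) tensor([[0.5000]])
        -2.2655
    """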
    def __init__(self, mean1, mean2, cov1=None, cov2=None, prec1=None, prec2=None):
        assert not (cov1 is None and prec1 is None), "either cov1 or prec1 needs to be provided"
        assert not (cov2 is None and prec2 is None), "either cov2 or prec2 needs to be provided"
        self.mean1 = mean1
        self.mean2 = mean2
        self._cov1 = cov1
        self._cov2 = cov2
        self._prec1 = prec1
        self._prec2 = prec2
        self._sum_cov = None
        self._sum_cov_inv = None
        self._cov = None
        self._prec = None
        self._torch = None
    @classmethod
    def _cost_comp(cls, cost1, cost2, comp):
        """
        Compare cost1 and cost2 using comp. Costs are numpy arrays [i, mm, mv, ma, va], where i is the number of
        matrix inversions, mm the number of matrix-matrix multiplications, mv the number of matrix-vector
        multiplications, ma the number of matrix-matrix additions, and va the number of vector-vector additions. The
        comparison is performed lexicographically in this order, that is::

            if comp(i1, i2):
                return True
            elif i1 != i2:
                return False
            # otherwise continue with mm1 and mm2 etc.
        """
        for c1, c2 in zip(cost1, cost2):
            if comp(c1, c2):
                return True
            elif not c1 == c2:
                return False
        return False
    @classmethod
    def _cost_lt(cls, cost1, cost2):
        return cls._cost_comp(cost1=cost1, cost2=cost2, comp=operator.lt)
    @property
    def cov1(self):
        if self._cov1 is None:
            self._cov1 = torch.inverse(self.prec1)
        return self._cov1
    def _cov1_cost(self):
        return np.array([int(self._cov1 is None), 0, 0, 0, 0])
    @property
    def cov2(self):
        if self._cov2 is None:
            self._cov2 = torch.inverse(self.prec2)
        return self._cov2
    def _cov2_cost(self):
        return np.array([int(self._cov2 is None), 0, 0, 0, 0])
    @property
    def prec1(self):
        if self._prec1 is None:
            self._prec1 = torch.inverse(self.cov1)
        return self._prec1
    def _prec1_cost(self):
        return np.array([int(self._prec1 is None), 0, 0, 0, 0])
    @property
    def prec2(self):
        if self._prec2 is None:
            self._prec2 = torch.inverse(self.cov2)
        return self._prec2
    def _prec2_cost(self):
        return np.array([int(self._prec2 is None), 0, 0, 0, 0])
    @property
    def sum_cov(self):
        if self._sum_cov is None:
            self._sum_cov = self.cov1 + self.cov2
        return self._sum_cov
    def _sum_cov_cost(self):
        if self._sum_cov is None:
            return np.array([0, 0, 0, 1, 0]) + self._cov1_cost() + self._cov2_cost()
        else:
            return np.array([0, 0, 0, 0, 0])
    @property
    def sum_cov_inv(self):
        if self._sum_cov_inv is None:
            self._sum_cov_inv = torch.inverse(self.sum_cov)
        return self._sum_cov_inv
    def _sum_cov_inv_cost(self):
        if self._sum_cov_inv is None:
            return np.array([1, 0, 0, 0, 0]) + self._sum_cov_cost()
        else:
            return np.array([0, 0, 0, 0, 0])
    @property
    def prec(self):
        if self._prec is None:
            if self._cost_lt(self._prec_v1_cost(), self._prec_v2_cost()):
                self._prec = self.prec1 + self.prec2
            else:
                self._prec = torch.inverse(self.cov)
        return self._prec
    def _prec_v1_cost(self):
        """Cost of computing precision by adding prec1 and prec2"""
        return np.array([0, 0, 0, 1, 0]) + self._prec1_cost() + self._prec2_cost()
    def _prec_v2_cost(self):
        """
        Cost of computing the precision by inverting cov (if cov is not present, assume infinite costs for
        tie-breaking).
        """
        if self._cov is None:
            return np.ones(5) * np.inf
        else:
            return np.array([1, 0, 0, 0, 0])
    def _prec_cost(self):
        if self._prec is None:
            return np.minimum(self._prec_v1_cost(), self._prec_v2_cost())
        else:
            return np.array([0, 0, 0, 0, 0])
    def _cov_v1(self):
        """Compute covariance from prec1 and prec2"""
        return torch.inverse(self.prec)
    def _cov_v1_cost(self):
        return np.array([1, 0, 0, 0, 0]) + self._prec_cost()
    def _cov_v2(self):
        """Compute covariance from cov1 and cov2"""
        return torch.matmul(torch.matmul(self.cov1, self.sum_cov_inv), self.cov2)
    def _cov_v2_cost(self):
        return np.array([0, 2, 0, 0, 0]) + self._cov1_cost() + self._cov2_cost() + self._sum_cov_inv_cost()
    @property
    def cov(self):
        """Compute covariance matrix using the least expensive method"""
        if self._cov is None:
            if self._cost_lt(self._cov_v1_cost(), self._cov_v2_cost()):
                self._cov = self._cov_v1()
            else:
                self._cov = self._cov_v2()
        return self._cov
    def _cov_cost(self):
        if self._cov is None:
            return np.minimum(self._cov_v1_cost(), self._cov_v2_cost())
        else:
            return np.array([0, 0, 0, 0, 0])
    def _mean_v1(self):
        """Compute mean from cov, prec1, and prec2"""
        return torch.einsum('...ab,...b->...a', self.cov,
                            torch.einsum('...ab,...b->...a', self.prec1, self.mean1) +
                            torch.einsum('...ab,...b->...a', self.prec2, self.mean2))
    def _mean_v1_cost(self):
        return np.array([0, 0, 3, 0, 1]) + self._cov_cost() + self._prec1_cost() + self._prec2_cost()
    def _mean_v2(self):
        """Compute mean from cov1, cov2, and the inverse of their summed covariance"""
        return torch.einsum('...ab,...b->...a', self.cov2,
                            torch.einsum('...ab,...b->...a', self.sum_cov_inv, self.mean1)) + \
               torch.einsum('...ab,...b->...a', self.cov1,
                            torch.einsum('...ab,...b->...a', self.sum_cov_inv, self.mean2))
    def _mean_v2_cost(self):
        return np.array([0, 0, 4, 0, 1]) + self._sum_cov_inv_cost() + self._cov1_cost() + self._cov2_cost()
    @cached_property
    def mean(self):
        """Compute mean using the least expensive method"""
        if self._cost_lt(self._mean_v1_cost(), self._mean_v2_cost()):
            return self._mean_v1()
        else:
            return self._mean_v2()
    @cached_property
    def torch(self):
        """Return the product distribution as torch MultivariateNormal"""
        if self._torch is None:
            self._torch = TorchMultivariateNormal(loc=self.mean, covariance_matrix=self.cov)
        return self._torch
    @cached_property
    def log_norm(self):
        """Compute log-normalisation using the least expensive method"""
        # if both have equal costs use precision matrix, which is less expensive within TorchMultivariateNormal
        if self._cost_lt(self._sum_cov_cost(), self._sum_cov_inv_cost()):
            return TorchMultivariateNormal(loc=self.mean1, covariance_matrix=self.sum_cov).log_prob(self.mean2)
        else:
            return TorchMultivariateNormal(loc=self.mean1, precision_matrix=self.sum_cov_inv).log_prob(self.mean2) 
class Product:
    def __init__(self,
                 means: torch.Tensor,
                 covariances: torch.Tensor = None,
                 precisions: torch.Tensor = None,
                 determinants: torch.Tensor = None,
                 scaled_means: torch.Tensor = None,
                 method: str = None):
        """
        Represents the product of N multivariate normal distributions over the same random variable. The inputs may
        additionally have an arbitrary number of batch dimensions (indicated as '...' below). The means and either
        the covariances or the precisions have to be provided; the remaining parameters are computed from them
        (providing them avoids recomputation). The `method` argument determines how the normalisation factor is
        computed. The methods return a triplet (log scaling factor of shape (...), mean of shape (...xD), covariance
        matrix of shape (...xDxD) of the product), which is also available via the `log_norm`, `mean`, and
        `covariance` attributes, respectively.

        :param means: Nx...xD array of means
        :param covariances: Nx...xDxD array of covariance matrices
        :param precisions: Nx...xDxD array of precision matrices
        :param determinants: Nx... array with determinants of the covariance matrices
        :param scaled_means: Nx...xD array with products of precision matrices and means
        :param method: method to use for computing the scaling factor (None/'default', 'iter', 'pair', 'commute');
         note that 'pair' uses only the first two components and therefore requires N=2
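
        For example, for two univariate standard normals centred at 0 and 2 (matching :class:`PairwiseProduct`):

        .. testsetup::

            from rbnet.multivariate_normal import Product
            import torch

        .. testcode::

            means = torch.tensor([[0.], [2.]])
            covariances = torch.eye(1).expand(2, 1, 1)
            p = Product(means=means, covariances=covariances)
            print(p.mean, round(p.log_norm.item(), 4))

        .. testoutput::

            tensor([1.]) -2.2655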
        """
        self._means = means
        self.N = self._means.shape[0]
        self.D = self._means.shape[-1]
        # make sure we have both the covariance and precision matrices
        assert covariances is not None or precisions is not None
        if covariances is not None:
            self._covariances = covariances
            self._precisions = torch.inverse(covariances)
        else:
            self._covariances = torch.inverse(precisions)
            self._precisions = precisions
        # compute determinants
        if determinants is None:
            self._determinants = torch.det(self._covariances)
        else:
            self._determinants = determinants
        # compute scaled means
        if scaled_means is None:
            self._scaled_means = torch.einsum('n...ab,n...b->n...a', self._precisions, self._means)
        else:
            self._scaled_means = scaled_means
        # compute parameters of product distribution
        self.precision = self._precisions.sum(dim=0)
        self.covariance = torch.inverse(self.precision)
        self.mean = torch.einsum('...ab,...b->...a', self.covariance, self._scaled_means.sum(dim=0))
        self.det = torch.det(self.covariance)
        # compute normalisation factor
        if method is None or method == 'default':
            self.log_norm, _, _ = self.product()
        elif method == 'iter':
            self.log_norm, _, _ = self.iter_product(means=self._means, covariances=self._covariances)
        elif method == 'pair':
            self.log_norm = PairwiseProduct(mean1=self._means[0], cov1=self._covariances[0],
                                            mean2=self._means[1], cov2=self._covariances[1]).log_norm
        elif method == 'commute':
            self.log_norm, _, _ = self.commuting_product()
        else:
            raise ValueError(f"Unknown method '{method}'")
    def product(self):
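        """
        Compute the normalisation factor of the product in closed form from the precomputed parameters.

        :return: triplet (log scaling factor, mean, covariance matrix of the product)
        """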
        quad_factor = torch.einsum('n...a,n...a->n...', self._scaled_means, self._means).sum(dim=0)
        mixed_factor = torch.einsum('n...a,m...a->nm...',
                                    torch.einsum('...ab,n...b->n...a', self.covariance, self._scaled_means),
                                    self._scaled_means).sum(dim=(0, 1))
        exp_factor = -(quad_factor - mixed_factor) / 2
        div = self.det.log() - self._determinants.log().sum(dim=0)
        pi = math.log(2 * math.pi) * (-self.D * (self.N - 1))
        det_factor = (pi + div) / 2
        return det_factor + exp_factor, self.mean, self.covariance 
    @classmethod
    def iter_product(cls, means, covariances):
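        """
        Compute the product by iteratively multiplying in one distribution at a time via :class:`PairwiseProduct`,
        accumulating the log scaling factors.

        :param means: Nx...xD array of means
        :param covariances: Nx...xDxD array of covariance matrices
        :return: triplet (log scaling factor, mean, covariance matrix of the product)
        """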
        ret_log_fac = 0
        ret_mean = None
        ret_cov = None
        for m, cov in zip(means, covariances):
            if ret_mean is None and ret_cov is None:
                ret_mean = m
                ret_cov = cov
                continue
            pp = PairwiseProduct(mean1=ret_mean, cov1=ret_cov, mean2=m, cov2=cov)
            ret_log_fac += pp.log_norm
            ret_mean = pp.mean
            ret_cov = pp.cov
        return ret_log_fac, ret_mean, ret_cov 
    def commuting_product(self):
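        """
        Compute the normalisation factor from pairwise terms; as the name suggests, this assumes commuting
        covariance matrices.

        :return: triplet (log scaling factor, mean, covariance matrix of the product)
        """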
        exp_factor = 0
        for i in range(self.N):
            for j in range(i + 1, self.N):
                cov_ij = self._covariances[j].matmul(self.precision).matmul(self._covariances[i])
                exp_factor += TorchMultivariateNormal(self._means[i], cov_ij).log_prob(self._means[j])
        det_factor = (math.log(2 * math.pi) * (self.D * (self.N - 1) * (self.N - 2) / 2) +
                      (self._determinants.log() * (self.N - 2)).sum(dim=0) -
                      self.det.log() * (self.N * (self.N - 1) / 2 - 1)
                      ) / 2
        return det_factor + exp_factor, self.mean, self.covariance 
class ApproximateMixture:
    def __init__(self,
                 means: Union[torch.Tensor, Iterable[torch.Tensor]],
                 log_weights: Union[torch.Tensor, Iterable[torch.Tensor]] = None,
                 covariances: Union[torch.Tensor, Iterable[torch.Tensor]] = None,
                 cat=False):
        """
        Approximate a mixture of N multivariate normal distributions by a single one via moment matching. This is
        equivalent to minimising the KL-divergence (or cross-entropy) from the mixture to the approximation, and to
        minimising the negative log-likelihood if the means are data points.

        :param means: (N,X,D) array with means / locations of data points; X are arbitrary batch dimensions
        :param log_weights: (N,X) array of log-weights of the components (optional; default is to assume uniform
         weights)
        :param covariances: (N,X,D,D) array of covariance matrices of the components (optional; default is to assume
         zero covariance, which corresponds to the components being treated as single data points)
        :param cat: if true, assume the inputs are iterables of tensors, which need to be concatenated first
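
        For example, two data points are summarised by their (biased) empirical mean and covariance:

        .. testsetup::

            from rbnet.multivariate_normal import ApproximateMixture
            import torch

        .. testcode::

            means = torch.tensor([[0., 0.], [2., 2.]])
            am = ApproximateMixture(means=means)
            print(am.mean)
            print(am.covariance)

        .. testoutput::

            tensor([1., 1.])
            tensor([[1., 1.],
                    [1., 1.]])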
        """
        # first concatenate inputs if requested
        if cat:
            means = torch.cat(tuple(means), dim=0)
            if log_weights is not None:
                log_weights = torch.cat(tuple(log_weights), dim=0)
            if covariances is not None:
                covariances = torch.cat(tuple(covariances), dim=0)
        # remember means
        self._means = means
        # get dimensions
        self.N = self._means.shape[0]
        self.D = self._means.shape[-1]
        # init uniform weights if not provided; get normalisation
        if log_weights is None:
            log_weights = (torch.ones(self._means.shape[:-1]) / self.N).log()
        self.log_norm = log_weights.logsumexp(dim=0)
        self.norm_log_weights = log_weights - self.log_norm
        self.norm_weights = self.norm_log_weights.exp()
        # compute mean
        self.mean = (self.norm_weights[:, ..., None] * self._means).sum(dim=0)
        # compute covariance component of means
        diff = self._means - self.mean[None, ..., :]
        mean_cov = (self.norm_weights[..., None, None] * diff[..., :, None] * diff[..., None, :]).sum(dim=0)
        # compute covariance
        if covariances is None:
            # only means component is non-zero (case for single data points)
            self.covariance = mean_cov
        else:
            # compute component of covariances and add to component of means
            self._cov_cov = (self.norm_weights[:, ..., None, None] * covariances).sum(dim=0)
            self.covariance = mean_cov + self._cov_cov