Source code for rbnet.autoencoder

import torch

from rbnet.base import Prior, NonTermVar, Transition, Cell
from rbnet.util import TupleTMap


class AutoencoderNonTermVar(NonTermVar):

    def __init__(self, dim, chart_type="TMap", *args, **kwargs):
        """
        A point-wise continuous non-terminal variable of dimensionality ``dim``. Distributions over these
        variables can be thought of as Dirac deltas (the limit of infinitely narrow Gaussians), represented
        by a location (a specific variable value) and a weight (in the case of mixtures or inside
        probabilities).

        :param dim: dimensionality of the variable
        :param chart_type: type of chart to use ("dict" or "TMap")
        """
        super().__init__(*args, **kwargs)
        self.dim = dim
        self.chart_type = chart_type

    def get_chart(self, n, *args, **kwargs):
        """
        Initialise a chart for a sequence of length ``n``.

        :param n: length of the sequence
        :return: chart
        """
        if self.chart_type == "dict":
            return {}
        elif self.chart_type == "TMap":
            return TupleTMap([
                torch.zeros((TupleTMap.size_from_n(n), self.dim)),  # the variable values
                torch.zeros(TupleTMap.size_from_n(n)),  # the inside probabilities
            ])
        else:
            raise ValueError(f"Unknown chart type '{self.chart_type}'")

    def mixture(self, components, weights=None, dim=0):
        """
        Approximate a mixture by its weighted average; the new weight is the sum of the mixture weights.
        Mixture weights are provided as part of ``components``; additional weights provided via ``weights``
        are multiplied onto the weights from ``components``.

        :param components: array-like with pairs of (values, weights) mixture components along ``dim``
        :param weights: [optional] weights of the mixture components; must be broadcastable against the
         weights in ``components``
        :param dim: integer or tuple of integers indicating the dimensions of ``components`` along which
         to sum to compute the mixture
        :return: distribution corresponding to the mixture
        """
        if len(components) == 0:
            return torch.zeros(self.dim), torch.zeros(1)
        mix_weights = torch.stack([c[1] for c in components])
        components = torch.stack([c[0] for c in components])
        if not isinstance(dim, tuple):
            dim = (dim,)
        if weights is not None:
            mix_weights = torch.as_tensor(weights) * mix_weights
        return (mix_weights * components).sum(dim=dim), mix_weights.sum(dim=dim)
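

# A minimal usage sketch for AutoencoderNonTermVar, assuming mixture weights are
# stored as shape-(1,) tensors (consistent with the empty-mixture return value
# above); all tensor values here are purely illustrative.
def _example_non_term_var():
    var = AutoencoderNonTermVar(dim=3)
    # a fresh chart for a sequence of length 4, bundling variable values and
    # inside probabilities for all spans
    chart = var.get_chart(n=4)
    # collapse two weighted delta components into a single location/weight pair:
    # the location is the weighted sum of the component locations, the new
    # weight is the sum of the component weights
    components = [
        (torch.tensor([1.0, 0.0, 0.0]), torch.tensor([0.25])),
        (torch.tensor([0.0, 1.0, 0.0]), torch.tensor([0.75])),
    ]
    value, weight = var.mixture(components)  # value [0.25, 0.75, 0.0], weight [1.0]
    return chart, value, weight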


class AutoencoderTransition(Transition):
    r"""
    An autoencoder transition combining a deterministic binary non-terminal and unary terminal transition.

    The general :meth:`~rbnet.base.Transition.inside_marginals` computation simplifies for autoencoders.
    First, we operate on point estimates (delta distributions), so we assume the following form for the
    inside distribution

    .. math::
        \beta_{i:k}(x_{i:k}) &:= w_{i:k} \, \delta(x_{i:k}=\bar{x}_{i:k}) \\
        \widetilde{\beta}_{i:j:k}(x_{i:k}) &:= \widetilde{w}_{i:j:k} \, \delta(x_{i:k}=\widetilde{x}_{i:j:k})~,

    where :math:`\bar{x}_{i:k}` and :math:`\widetilde{x}_{i:j:k}` define the locations of the delta
    distributions and :math:`w_{i:k}` and :math:`\widetilde{w}_{i:j:k}` their norm. For binary
    non-terminals, we then get

    .. math::
        \widetilde{\beta}_{i:j:k}(x_{i:k})
        &= \int\int p_{N}(x_{i:j}, x_{j:k} \mid x_{i:k}) \, \beta_{i:j}(x_{i:j}) \, \beta_{j:k}(x_{j:k}) \, dx_{i:j} \, dx_{j:k} \\
        &= p_{N}(\bar{x}_{i:j}, \bar{x}_{j:k} \mid x_{i:k}) \, w_{i:j} \, w_{j:k}

    and for unary terminals

    .. math::
        \widetilde{\beta}_{i:i+1}(x_{i:i+1}) = p_{T}(y_{i+1} \mid x_{i:i+1})~.

    We now recover the form assumed above by fixing the values of :math:`x_{i:k}` and :math:`x_{i:i+1}`
    via a deterministic encoder, while the transition probabilities are provided by the stochastic forward
    model, i.e., the decoder

    .. math::
        \widetilde{x}_{i:j:k} &:= \mbox{non-terminal encoder}(\bar{x}_{i:j}, \bar{x}_{j:k}) \\
        p_{N}(\bar{x}_{i:j}, \bar{x}_{j:k} \mid x_{i:k}) &:= \mbox{non-terminal decoder}(\bar{x}_{i:j}, \bar{x}_{j:k} \mid \widetilde{x}_{i:j:k}) \\
        \widetilde{x}_{i:i+1} &:= \mbox{terminal encoder}(y_{i+1}) \\
        p_{T}(y_{i+1} \mid x_{i:i+1}) &:= \mbox{terminal decoder}(y_{i+1} \mid \widetilde{x}_{i:i+1})~.
    """

    def __init__(self, terminal_encoder, terminal_decoder, non_terminal_encoder, non_terminal_decoder,
                 left_idx=0, right_idx=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.left_idx = left_idx
        self.right_idx = right_idx
        self.terminal_encoder = terminal_encoder
        self.terminal_decoder = terminal_decoder
        self.non_terminal_encoder = non_terminal_encoder
        self.non_terminal_decoder = non_terminal_decoder

    def inside_marginals(self, location, inside_chart, terminal_chart, value=None, **kwargs):
        if value is not None:
            raise NotImplementedError("Conditional inside probabilities are currently not implemented")
        if isinstance(location, tuple) and len(location) == 2:
            start, end = location
            if end - start <= 1:
                # terminal transition
                parent_var = self.terminal_encoder(terminal_chart[start])
                transition_prob = self.terminal_decoder(parent_var, terminal_chart[start])
                return [(parent_var, transition_prob)]
            else:
                inside_marginals = []
                for split in range(start + 1, end):
                    # get inside probabilities of the two children
                    left_var, left_inside = inside_chart[self.left_idx][start, split]
                    right_var, right_inside = inside_chart[self.right_idx][split, end]
                    parent_var = self.non_terminal_encoder(left_var, right_var)
                    transition_prob = self.non_terminal_decoder(parent_var, left_var, right_var)
                    inside_marginals.append((parent_var, transition_prob * left_inside * right_inside))
                return inside_marginals
        else:
            raise ValueError(f"Expected location to be a (start, end) index pair, but got: {location}")
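

# A minimal sketch of the four callables an AutoencoderTransition expects,
# matching the signatures used in inside_marginals above; the Gaussian-style
# decoder densities are an illustrative choice, not prescribed by the class.
def _example_transition():

    def terminal_encoder(y):
        # deterministic encoding of an observed terminal; here simply the identity
        return y

    def terminal_decoder(x, y):
        # unnormalised isotropic Gaussian density of the terminal y around its encoding x
        return torch.exp(-0.5 * ((y - x) ** 2).sum(-1, keepdim=True))

    def non_terminal_encoder(left, right):
        # deterministic encoding of the two children into a parent value
        return 0.5 * (left + right)

    def non_terminal_decoder(parent, left, right):
        # unnormalised density of the two children given the parent value
        err = ((left - parent) ** 2).sum(-1) + ((right - parent) ** 2).sum(-1)
        return torch.exp(-0.5 * err).unsqueeze(-1)

    return AutoencoderTransition(terminal_encoder, terminal_decoder,
                                 non_terminal_encoder, non_terminal_decoder)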


class AutoencoderCell(Cell):

    def __init__(self, variable, transition, *args, **kwargs):
        """
        :param variable: the :class:`~AutoencoderNonTermVar` for this cell
        :param transition: the :class:`~AutoencoderTransition` for this cell
        :param args: passed on to super().__init__
        :param kwargs: passed on to super().__init__
        """
        super().__init__(variable=variable, *args, **kwargs)
        self._transition = transition

    def transitions(self):
        yield self._transition

    def inside_mixture(self, inside_marginals):
        assert len(inside_marginals) == 1, \
            f"Expected only one element for a single transition, but got {len(inside_marginals)}"
        return self.variable.mixture(components=inside_marginals[0])
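

# A minimal sketch assembling a cell from the variable class above and the
# illustrative transition helper; _example_transition is not part of the module.
def _example_cell():
    var = AutoencoderNonTermVar(dim=3)
    return AutoencoderCell(variable=var, transition=_example_transition())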


class AutoencoderPrior(Prior):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def marginal_likelihood(self, root_location, inside_chart, **kwargs):
        assert len(inside_chart) == 1, f"Expected inside chart with one variable, but got {len(inside_chart)}"
        return inside_chart[0][root_location]
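

# A minimal sketch of reading off the marginal likelihood, assuming a completed
# inside chart for the single non-terminal variable that is indexed by
# (start, end) spans, with the root spanning the whole length-n sequence.
def _example_marginal_likelihood(inside_chart, n):
    prior = AutoencoderPrior()
    return prior.marginal_likelihood(root_location=(0, n), inside_chart=[inside_chart])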