# Source code for rbnet.base

from typing import Iterable
from abc import ABC, abstractmethod

import torch

from rbnet.util import ConstrainedModuleList


class RBN(ABC):
    """Abstract base class for RBNs.

    Derived classes provide the schedule, cells, charts and prior; the generic
    inside pass that ties them together is implemented in :meth:`inside`.
    """

    def __init__(self, *args, **kwargs):
        # Forward all arguments to the next class in the MRO.
        super().__init__(*args, **kwargs)

    @abstractmethod
    def inside_schedule(self, *args, **kwargs):
        r"""
        Iterate through (batches of) non-terminal locations for computing inside probabilities.

        The iteration order has to respect dependencies: if a non-terminal :math:`x` generates another
        non-terminal :math:`x'` (:math:`x \rightarrow x'` in the generative direction), then :math:`x'`
        has to be visited BEFORE :math:`x`.

        :return: location or batch of locations
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def root_location(self):
        """
        Location of the root variables.

        This is typically the value returned in the last iteration of :meth:`inside_schedule`.

        :return: location of the root variables
        """
        raise NotImplementedError

    @abstractmethod
    def cells(self):
        """
        Return an iterable over the cells (one per non-terminal variable).
        """
        raise NotImplementedError

    def init_inside(self, *args, **kwargs):
        """
        Hook for preparing the parsing of a new input.

        :meth:`inside` calls this function with all the parameters it received. The default implementation
        does nothing; derived classes may override it, for instance to initialise the charts for the
        non-terminal variables.
        """

    @abstractmethod
    def update_inside_chart(self, var_idx, locations, values):
        """
        Store inside probabilities for one variable in its chart.

        :param var_idx: specifies the variable
        :param locations: locations to update
        :param values: values to store in chart
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def inside_chart(self):
        """
        Chart with the inside probabilities for all variables.

        :return: inside chart
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def terminal_chart(self):
        """
        Chart with the terminal variables.

        :return: terminal chart
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def prior(self):
        """
        Prior transition (typically an instance of :class:`Prior`); it has to implement
        :meth:`Prior.marginal_likelihood`.

        :return: prior
        """
        raise NotImplementedError

    def inside(self, *args, **kwargs):
        """
        Run the inside pass and return the marginal data likelihood.

        All arguments are handed to :meth:`init_inside`.

        :return: marginal likelihood
        """
        # let the derived class set up whatever it needs for this input
        self.init_inside(*args, **kwargs)
        # walk through the (batches of) locations in dependency order (e.g. one level in CYK at a time)
        for location in self.inside_schedule():
            # every variable has one cell, which knows the transitions allowed for that variable
            for var_idx, cell in enumerate(self.cells()):
                # one entry of inside marginals per allowed transition
                marginals = [
                    transition.inside_marginals(location=location,
                                                inside_chart=self.inside_chart,
                                                terminal_chart=self.terminal_chart)
                    for transition in cell.transitions()
                ]
                # mix the marginals and write the result into the chart of this variable
                mixture = cell.inside_mixture(marginals)
                self.update_inside_chart(var_idx=var_idx, locations=location, values=mixture)
        # the prior turns the inside probabilities at the root into the marginal likelihood
        return self.prior.marginal_likelihood(root_location=self.root_location,
                                              inside_chart=self.inside_chart)
class Transition(ABC, torch.nn.Module):
    """
    Abstract base class for RBN transitions.

    Concrete transitions have to implement :meth:`~Transition.inside_marginals`.
    """

    def __init__(self, *args, **kwargs):
        # Forward all arguments to the next class in the MRO.
        super().__init__(*args, **kwargs)

    @abstractmethod
    def inside_marginals(self, location, inside_chart, terminal_chart, **kwargs):
        r"""
        Compute the marginals over inside probabilities

        .. math::
            \widetilde{\beta}_{i:\ldots:k}(x_{i:k}) = {\int\cdots\int}_{\{v_{j:j^\prime}\in\mathcal{X}\}}
            p_{\tau}(v_{i:j_1}, \ldots, v_{j_{\eta-1}:k} \mid x_{i:k})
            \prod_{\{v_{j:j^\prime}\in\mathcal{X}\}} \beta(v_{j:j^\prime}),

        for all possible splitting points (also see :ref:`here <general inside probabilities>`).

        ``location`` specifies the variable's location in the parse chart (the indices :math:`i` and :math:`k`
        in the equation above). The possible splitting points follow from it: a transition of arity
        :math:`\eta` has :math:`\eta` splitting points :math:`j_1, \ldots,j_{\eta-1}`.

        The marginals should always be returned in an array or iterable whose first dimension corresponds to
        all possible combinations of splitting points, even for transitions with arity :math:`\eta\neq2`:
        for :math:`\eta=1` there are no splits and the first dimension should be of size 1; for
        :math:`\eta>1` all possible combinations of the :math:`\eta-1` splitting points should be listed in
        a flattened form in the first dimension.

        Additional dimensions may be used to represent the dependency of the marginal on the variable
        :math:`x_{i:k}` (e.g. for a discrete variable, the second dimension may list the marginal for each
        possible value :math:`x_{i:k}` can take; for a continuous variable, the marginal may be represented
        by a set of parameters). The output of this function is typically handled by a custom implementation
        of :meth:`Cell.inside_mixture() <rbnet.base.Cell.inside_mixture>`.

        :param location: location of the variable for which to compute the inside marginals
        :param inside_chart: a lookup chart with inside probabilities for other variables
        :param terminal_chart: a lookup chart with values of the terminal variables
        :return: array-like or iterable with inside probabilities
        """
        raise NotImplementedError
class Prior(ABC, torch.nn.Module):
    """
    Abstract base class for prior transitions.

    A prior transition is similar to a normal Transition, except that it cannot directly generate terminal
    variables and has no parent variables.
    """

    def __init__(self, *args, **kwargs):
        # Forward all arguments to the next class in the MRO.
        super().__init__(*args, **kwargs)

    @abstractmethod
    def marginal_likelihood(self, root_location, inside_chart, **kwargs):
        r"""
        Compute the marginal data likelihood

        .. math::
            p(\mathbf{Y}) = \sum_{x\in\mathcal{X}} w_x \int \beta(x_{0:n}) \ p_P(x_{0:n}) \ dx_{0:n},

        as described in more detail :ref:`here <marginal likelihood>`.

        :param root_location: location of the root variables
        :param inside_chart: chart of inside probabilities
        :return: marginal likelihood
        """
        raise NotImplementedError
class NonTermVar(ABC):
    """
    Abstract base class for non-terminal variables.

    An instance of :class:`NonTermVar` represents a specific type of template variable. Use
    :meth:`get_chart` to request a parse chart for that variable for specific input data and
    :meth:`mixture` to compute mixtures over this variable type.
    """

    def __init__(self, *args, **kwargs):
        # Forward all arguments to the next class in the MRO.
        super().__init__(*args, **kwargs)

    @abstractmethod
    def get_chart(self, *args, **kwargs):
        """
        Return a parse chart to store this variable type in.

        The arguments are implementation specific; they typically carry information about the input data
        for which the chart is requested (e.g. the length of the sequence for sequential input data).
        """
        raise NotImplementedError

    @abstractmethod
    def mixture(self, *args, **kwargs):
        """
        Compute a mixture over this variable type.

        The arguments are implementation specific; they typically include iterables over the mixture
        components and possibly their weights.
        """
        raise NotImplementedError
class Cell(ABC, torch.nn.Module):
    """
    Abstract base class for RBN cells, each associated to a non-terminal template variable.

    A cell holds all transitions that are possible from its variable (accessible via :meth:`transitions`)
    and implements the computation of the mixture of inside probabilities over that variable
    (:meth:`inside_mixture`).
    """

    def __init__(self, variable, *args, **kwargs):
        """
        Create a cell for the given non-terminal variable.

        Any further arguments are forwarded to the next class in the MRO.

        :param variable: non-terminal variable
        """
        super().__init__(*args, **kwargs)
        self.variable = variable

    @abstractmethod
    def transitions(self) -> Iterable[Transition]:
        """
        Iterate through all possible transitions.
        """
        raise NotImplementedError

    @abstractmethod
    def inside_mixture(self, inside_marginals):
        r"""
        For a list of inside marginals :math:`\widetilde{\beta}_{i:\ldots:k}(x_{i:k})`, compute the mixture

        .. math::
            \beta(x_{i:k}) = \sum_{\tau \in \mathcal{T}_x} p_S(z_{i:k}=\tau \mid x_{i:k})
            {\sum\cdots\sum}_{i<j_1<\ldots<j_{\eta-1}<k} \ \widetilde{\beta}_{i:\ldots:k}(x_{i:k})

        as described in more detail :ref:`here <general inside probabilities>`. The inside marginals are
        typically computed by a custom implementation of
        :meth:`Transition.inside_marginals() <rbnet.base.Transition.inside_marginals>`.

        :param inside_marginals: iterable over inside marginals (as returned by
            :meth:`Transition.inside_marginals`)
        :return: representation of the inside probability :math:`\beta(x_{i:k})`
        """
        raise NotImplementedError
class SequentialRBN(RBN, torch.nn.Module):
    """
    RBN for sequential input, where locations are ``(start, end)`` pairs.

    The per-input state (``n`` and the charts) is ``None`` until :meth:`init_inside` has been called.
    """

    def __init__(self, cells, prior, *args, **kwargs):
        """
        :param cells: cells of the network; wrapped in a ``ConstrainedModuleList``
        :param prior: prior transition, returned by :attr:`prior`
        """
        super().__init__(*args, **kwargs)
        self._cells = ConstrainedModuleList(cells)
        self._prior = prior
        # per-input state, filled in by init_inside
        self.n = None
        self._terminal_chart = None
        self._inside_chart = None

    def init_inside(self, sequence):
        """
        Prepare for parsing ``sequence``: remember its length, use the sequence itself as terminal chart and
        request one fresh inside chart per cell from the cell's variable.

        :param sequence: input sequence (has to support ``len``)
        """
        length = len(sequence)
        self.n = length
        self._terminal_chart = sequence
        self._inside_chart = [cell.variable.get_chart(length) for cell in self._cells]

    def inside_schedule(self, *args, **kwargs):
        """
        Yield all ``(start, end)`` locations ordered by increasing span width (from 1 up to ``n``) and, for
        each width, by increasing start index. Arguments are accepted but ignored.
        """
        for width in range(1, self.n + 1):
            for left in range(self.n - width + 1):
                right = left + width
                yield left, right

    @property
    def root_location(self):
        """Location spanning the whole input: ``(0, n)``."""
        return 0, self.n

    def cells(self):
        """Return the module list holding the cells."""
        return self._cells

    def update_inside_chart(self, var_idx, locations, values):
        """
        Write ``values`` at ``locations`` into the inside chart of variable ``var_idx``.

        :param var_idx: index of the variable (position of its cell)
        :param locations: locations to update
        :param values: values to store in chart
        """
        chart = self._inside_chart[var_idx]
        chart[locations] = values

    @property
    def inside_chart(self):
        """List with one inside chart per cell (``None`` before :meth:`init_inside`)."""
        return self._inside_chart

    @property
    def terminal_chart(self):
        """The input sequence passed to :meth:`init_inside` (``None`` before that)."""
        return self._terminal_chart

    @property
    def prior(self):
        """The prior transition given at construction."""
        return self._prior
def main():
    """Script entry point; currently a placeholder that does nothing."""
    return None


if __name__ == "__main__":
    main()