# Source code for rbnet.sequential

import torch

from rbnet.base import RBN, Transition
from rbnet.util import ConstrainedModuleList


class SequentialRBN(RBN, torch.nn.Module):
    """RBN whose locations are spans ``(start, end)`` over a linear sequence.

    After :meth:`init_inside` has been called with a sequence of length ``n``,
    valid spans satisfy ``0 <= start < end <= n`` and the root span is
    ``(0, n)``.
    """

    def __init__(self, cells, prior, *args, **kwargs):
        """Store cells and prior; forward remaining arguments to the base classes.

        :param cells: iterable of cells, wrapped in a ``ConstrainedModuleList``
        :param prior: object returned by the :attr:`prior` property
        """
        super().__init__(*args, **kwargs)
        # Assignment order is kept as-is: torch.nn.Module records submodules
        # in the order they are assigned.
        self._cells = ConstrainedModuleList(cells)
        self._prior = prior
        # Sequence-dependent state; init_inside() fills n and the
        # terminal/inside charts.
        self.n = None
        self._terminal_chart = None
        self._inside_chart = None
        self._outside_chart = None

    def init_inside(self, sequence):
        """Prepare the charts for an inside pass over ``sequence``.

        Records ``len(sequence)`` as ``n``, keeps ``sequence`` itself as the
        terminal chart and builds one inside chart per cell from
        ``cell.variable.get_chart(n)``.
        """
        self.n = len(sequence)
        self._terminal_chart = sequence
        self._inside_chart = [cell.variable.get_chart(self.n) for cell in self._cells]

    def inside_schedule(self, *args, **kwargs):
        """Yield every span ``(start, end)`` of the current sequence, shortest first.

        Spans of equal length are produced left to right, so every proper
        sub-span of a span is yielded before the span itself. Extra arguments
        are ignored here.
        """
        yield from (
            (begin, begin + width)
            for width in range(1, self.n + 1)
            for begin in range(self.n - width + 1)
        )

    @RBN.root_location.getter
    def root_location(self):
        """The span ``(0, n)`` covering the whole sequence."""
        return (0, self.n)

    def cells(self):
        """Return the ``ConstrainedModuleList`` of cells given at construction."""
        return self._cells

    def update_inside_chart(self, var_idx, locations, values):
        """Write ``values`` into the inside chart of variable ``var_idx`` at ``locations``."""
        target = self._inside_chart[var_idx]
        target[locations] = values

    @RBN.inside_chart.getter
    def inside_chart(self):
        """Per-cell inside charts built by :meth:`init_inside`."""
        return self._inside_chart

    @RBN.terminal_chart.getter
    def terminal_chart(self):
        """The sequence passed to :meth:`init_inside`."""
        return self._terminal_chart

    @RBN.prior.getter
    def prior(self):
        """The prior given at construction."""
        return self._prior
class SequentialBinaryTransition(Transition):
    """Transition that splits a sequential span into two adjacent sub-spans."""

    def iterate_inside_splits(self, location):
        """Yield every binary split of the span ``location``.

        :param location: span ``(start, end)``
        :return: generator of ``(start, split, end)`` for each
            ``start < split < end``; spans of length one or less yield nothing.
        """
        first, last = location
        if last - first <= 1:
            return  # too short to split into two non-empty parts
        yield from ((first, mid, last) for mid in range(first + 1, last))
class SequentialTerminalTransition(Transition):
    """Transition that maps a single-element sequential span to its terminal."""

    def iterate_inside_splits(self, location):
        """Yield the terminal index covered by the span ``location``.

        :param location: span ``(start, end)``
        :return: generator yielding ``start`` when the span has length exactly
            one, and nothing otherwise.
        """
        start, end = location
        # Longer spans must be split further; empty or inverted spans cover no
        # element at all, so they must not produce a terminal index either
        # (for start == n that index would be out of range).
        if end - start == 1:
            yield start