Custom type definitions

The library is built on Python's typing system. Based on this type system, we define the following custom types.

Activations

from typing import BinaryIO, Callable, Dict, List, Tuple, Union

from torch import Tensor
from torchtext.data import Example

# ACTIVATION DICTS
# An activation is identified by a (layer, name) pair: an int layer and a
# str name.
ActivationName = Tuple[int, str]  # (layer, name)
# A list of such (layer, name) pairs.
ActivationNames = List[ActivationName]

# Maps each (layer, name) pair to a Tensor.
ActivationDict = Dict[ActivationName, Tensor]

# LM's layer sizes: (layer, name) -> size
SizeDict = Dict[ActivationName, int]


# EXTRACTION
# Maps each (layer, name) pair to a binary file object.
ActivationFiles = Dict[ActivationName, BinaryIO]

# Predicate with signature (token index, corpus item) -> bool.
# NOTE(review): presumably decides which tokens are selected during
# extraction -- confirm against the callers.
SelectionFunc = Callable[[int, Example], bool]

# List of (start, stop) int pairs: [(start, stop), ...]
# NOTE(review): whether `stop` is inclusive is not visible here -- confirm.
ActivationRanges = List[Tuple[int, int]]

# Callback that takes no arguments and returns None.
RemoveCallback = Callable[[], None]


# INDEXING
# Activation indexing, as done in ActivationReader: a single int, a slice,
# a list of ints, or a Tensor.
ActivationIndex = Union[int, slice, List[int], Tensor]

# Either a bare index, or an (index, (layer, name)) pair.
ActivationKey = Union[ActivationIndex, Tuple[ActivationIndex, ActivationName]]

Probing

from collections import namedtuple
from typing import Callable, Union

from torchtext.data import Example

# Named tuple holding activations, labels and control labels, once for the
# train fields and once for the test fields. The field order below is part
# of the interface: instances can be built and unpacked positionally.
DataDict = namedtuple(
    "DataDict",
    "train_activations train_labels train_control_labels "
    "test_activations test_labels test_control_labels",
)

# Named tuple with three positional fields, in this order: an activation
# reader, the labels, and the control labels.
DataSplit = namedtuple(
    typename="DataSplit",
    field_names=("activation_reader", "labels", "control_labels"),
)

# Named tuple of configuration values; no field has a default, so all five
# must be supplied.
# NOTE(review): presumably the training settings of a diagnostic classifier
# ("DC") -- confirm against the code that consumes it.
DCConfig = namedtuple("DCConfig", "lr max_epochs rank lambda1 verbose")

# Control task, see https://www.aclweb.org/anthology/D19-1275/
# Signature: (w position, batch item) -> label, where the label is a str or
# an int.
# NOTE(review): "w position" presumably means the word position within the
# batch item -- confirm.
ControlTask = Callable[[int, Example], Union[str, int]]