Source code for diagnnose.extract.simple_extract

import shutil
from typing import TYPE_CHECKING, Optional, Tuple

from diagnnose.activations import ActivationReader
from diagnnose.activations.selection_funcs import return_all
from diagnnose.corpus import Corpus
from diagnnose.typedefs.activations import (
    ActivationNames,
    RemoveCallback,
    SelectionFunc,
)
from diagnnose.utils.misc import suppress_print

if TYPE_CHECKING:
    from diagnnose.models import LanguageModel

from .extractor import BATCH_SIZE, Extractor


[docs]@suppress_print def simple_extract( model: "LanguageModel", corpus: Corpus, activation_names: ActivationNames, activations_dir: Optional[str] = None, batch_size: int = BATCH_SIZE, selection_func: SelectionFunc = return_all, ) -> Tuple[ActivationReader, RemoveCallback]: """Basic extraction method. Parameters ---------- model : LanguageModel Language model that inherits from LanguageModel. corpus : Corpus Corpus containing sentences to be extracted. activation_names : List[tuple[int, str]] List of (layer, activation_name) tuples activations_dir : str, optional Directory to which activations will be written. If not provided the `extract()` method will only return the activations without writing them to disk. selection_func : SelectionFunc Function which determines if activations for a token should be extracted or not. batch_size : int, optional Amount of sentences processed per forward step. Higher batch size increases extraction speed, but should be done accordingly to the amount of available RAM. Defaults to 1. Returns ------- activation_reader : ActivationReader ActivationReader for the activations that have been extracted. remove_activations : RemoveCallback Callback function that can be executed at the end of a procedure that depends on the extracted activations. Removes all the activations that have been extracted. Takes no arguments. """ extractor = Extractor( model, corpus, activation_names, activations_dir=activations_dir, selection_func=selection_func, batch_size=batch_size, ) activation_reader = extractor.extract() def remove_activations(): if activations_dir is not None: shutil.rmtree(activations_dir) return activation_reader, remove_activations