Source code for diagnnose.extract.extractor
from contextlib import ExitStack
from typing import TYPE_CHECKING, Optional, Union
import torch
from torchtext.data import Batch
from tqdm import tqdm
import diagnnose.activations.selection_funcs as selection_funcs
from diagnnose.activations import ActivationReader, ActivationWriter
from diagnnose.activations.selection_funcs import return_all
from diagnnose.corpus import Corpus, create_iterator
from diagnnose.typedefs.activations import (
ActivationDict,
ActivationNames,
ActivationRanges,
SelectionFunc,
)
if TYPE_CHECKING:
from diagnnose.models import LanguageModel
BATCH_SIZE = 1024
[docs]class Extractor:
"""Extracts all intermediate activations of a LM from a corpus.
Only activations that are provided in activation_names will be
stored in a pickle file. Each activation is written to its own file.
Parameters
----------
model : LanguageModel
Language model that inherits from LanguageModel.
corpus : Corpus
Corpus containing sentences to be extracted.
activation_names : List[tuple[int, str]], optional
List of (layer, activation_name) tuples. If not provided all
activation_names corresponding to the ``model`` will be
extracted.
activations_dir : str, optional
Directory to which activations will be written. If not provided
the `extract()` method will only return the activations without
writing them to disk.
selection_func : Union[SelectionFunc, str]
Function which determines if activations for a token should
be extracted or not. Can also be provided as a string,
indicating the method name of one of the default
selection_funcs in
:py:mod:`diagnnose.activations.selection_funcs`.
batch_size : int, optional
Amount of sentences processed per forward step. Higher batch
size increases extraction speed, but should be done
accordingly to the amount of available RAM. Defaults to 1.
"""
def __init__(
self,
model: "LanguageModel",
corpus: Corpus,
activation_names: Optional[ActivationNames] = None,
activations_dir: Optional[str] = None,
selection_func: Union[SelectionFunc, str] = return_all,
batch_size: int = BATCH_SIZE,
) -> None:
self.model = model
self.corpus = corpus
self.activation_names = activation_names or model.activation_names()
if isinstance(selection_func, str):
self.selection_func: SelectionFunc = getattr(
selection_funcs, selection_func
)
else:
self.selection_func: SelectionFunc = selection_func
self.batch_size = batch_size
self.activation_ranges = []
self.set_activation_ranges()
if activations_dir is None:
self.activation_writer: Optional[ActivationWriter] = None
else:
self.activation_writer = ActivationWriter(activations_dir)
[docs] def extract(self) -> ActivationReader:
"""Extracts embeddings from a corpus.
Uses :class:`contextlib.ExitStack` to write to multiple files
simultaneously.
Returns
-------
activation_reader : ActivationReader
After extraction an activation_reader is returned that
provides direct access to the extracted activations.
"""
print(f"\nStarting extraction of {len(self.corpus)} sentences...")
if self.activation_writer is not None:
with ExitStack() as stack:
self.activation_writer.create_output_files(stack, self.activation_names)
self._extract_corpus(dump=True)
self.activation_writer.dump_meta_info(
self.activation_ranges, self.selection_func
)
activation_reader = ActivationReader(
activations_dir=self.activation_writer.activations_dir,
activation_names=self.activation_names,
)
else:
corpus_activations = self._extract_corpus(dump=False)
activation_reader = ActivationReader(
activation_dict=corpus_activations,
activation_names=self.activation_names,
activation_ranges=self.activation_ranges,
selection_func=self.selection_func,
)
n_extracted = self.activation_ranges[-1][-1]
print(f"Extraction finished, {n_extracted} activations have been extracted.")
return activation_reader
def _extract_corpus(self, dump: bool = True) -> ActivationDict:
tot_extracted = self.activation_ranges[-1][1]
corpus_activations: ActivationDict = self._init_activation_dict(
tot_extracted, dump=dump
)
corpus = self._filter_corpus()
iterator = create_iterator(
corpus, batch_size=self.batch_size, device=self.model.device
)
for batch in tqdm(iterator, unit="batch"):
batch_activations = self._extract_batch(batch)
if dump:
self.activation_writer.dump_activations(batch_activations)
else:
# Insert extracted batch activations into full corpus activations dict.
batch_start = self.activation_ranges[batch.sen_idx[0]][0]
batch_stop = self.activation_ranges[batch.sen_idx[-1]][1]
for a_name, activations in batch_activations.items():
corpus_activations[a_name][batch_start:batch_stop] = activations
return corpus_activations
def _filter_corpus(self) -> Corpus:
""" Skip items for which selection_func yields 0 activations. """
sen_ids = [
idx
for idx, (start, stop) in enumerate(self.activation_ranges)
if start != stop
]
if len(sen_ids) != len(self.corpus):
return self.corpus.slice(sen_ids)
return self.corpus
def _extract_batch(self, batch: Batch) -> ActivationDict:
"""Processes the items in `batch` and selects the activations
that should should be extracted according to selection_func.
"""
sens, sen_lens = getattr(batch, self.corpus.sen_column)
compute_out = any("out" in a_name for a_name in self.activation_names)
kwargs = {}
if getattr(self.model, "compute_pseudo_ll", False):
kwargs["mask_idx"] = self.corpus.tokenizer.mask_token_id
kwargs["selection_func"] = self.selection_func
kwargs["batch"] = batch
with torch.no_grad():
# a_name -> batch_size x max_sen_len x nhid
all_activations: ActivationDict = self.model(
input_ids=sens,
input_lengths=sen_lens,
compute_out=compute_out,
only_return_top_embs=False,
**kwargs,
)
# a_name -> n_items_in_batch x nhid
batch_activations = self._select_activations(all_activations, batch)
return batch_activations
def _select_activations(
self,
all_activations: ActivationDict,
batch: Batch,
) -> ActivationDict:
""" Selects only the activations that pass selection_func. """
batch_start = self.activation_ranges[batch.sen_idx[0]][0]
batch_end = self.activation_ranges[batch.sen_idx[-1]][1]
n_items_in_batch = batch_end - batch_start
batch_activations: ActivationDict = self._init_activation_dict(n_items_in_batch)
a_idx = 0
for b_idx, sen_idx in enumerate(batch.sen_idx):
item = self.corpus[sen_idx]
sen_len = len(getattr(item, self.corpus.sen_column))
for w_idx in range(sen_len):
if self.selection_func(w_idx, item):
for a_name in batch_activations:
selected_activation = all_activations[a_name][b_idx][w_idx]
batch_activations[a_name][a_idx] = selected_activation
a_idx += 1
return batch_activations
[docs] def set_activation_ranges(self) -> None:
activation_ranges: ActivationRanges = []
tot_extracted = 0
for item in self.corpus:
start = tot_extracted
sen_len = len(getattr(item, self.corpus.sen_column))
for w_idx in range(sen_len):
if self.selection_func(w_idx, item):
tot_extracted += 1
activation_ranges.append((start, tot_extracted))
self.activation_ranges = activation_ranges
def _init_activation_dict(self, n_items: int, dump: bool = False) -> ActivationDict:
# If activations are dumped we don't keep track of the full activation dictionary,
# so an empty dictionary is returned.
if dump:
return {}
corpus_activations = {
a_name: torch.zeros(
n_items, self.model.nhid(a_name), device=self.model.device
)
for a_name in self.activation_names
}
return corpus_activations