Source code for diagnnose.corpus.create_labels

from typing import Optional

import torch
from torch import Tensor

from diagnnose.corpus import Corpus
from diagnnose.typedefs.activations import SelectionFunc
from diagnnose.typedefs.probe import ControlTask


def create_labels_from_corpus(
    corpus: Corpus,
    selection_func: SelectionFunc = lambda w_idx, item: True,
    control_task: Optional[ControlTask] = None,
) -> Tensor:
    """Creates labels based on the selection_func that was used during extraction.

    Parameters
    ----------
    corpus : Corpus
        Labeled corpus containing sentence and label information. Its
        ``examples`` are iterated, and the sentence and labels of each item
        are read from the ``corpus.sen_column`` and ``corpus.labels_column``
        attributes. A string of labels is split on whitespace.
    selection_func : SelectionFunc, optional
        Function that determines whether a label should be stored. It is
        called as ``selection_func(w_idx, item)`` for every token position
        ``w_idx`` of every corpus item. Defaults to selecting every position.
    control_task : ControlTask, optional
        Control task function of Hewitt et al. (2019), mapping a corpus item
        to a random label. When given, it is called as
        ``control_task(w_idx, item)`` for each selected position, and its
        result replaces the corpus label.

    Returns
    -------
    labels : Tensor
        1-d ``torch.long`` tensor with one label index per selected token,
        in corpus order.

    Notes
    -----
    As a side effect, ``corpus.fields[corpus.labels_column].vocab`` is
    overwritten with the label -> index mapping used for the returned tensor.
    """
    all_labels = []
    for item in corpus.examples:
        sen = getattr(item, corpus.sen_column)
        labels = getattr(item, corpus.labels_column)
        if isinstance(labels, str):
            labels = labels.split()

        # If there is exactly one label per token, the label index follows the
        # token position. Otherwise a single label is reused for every
        # selected token, and multiple labels are consumed in order, one per
        # selected token.
        each_token_labeled = len(sen) == len(labels)
        label_idx = 0
        for w_idx in range(len(sen)):
            if selection_func(w_idx, item):
                if control_task is not None:
                    label = control_task(w_idx, item)
                else:
                    label = labels[label_idx]
                all_labels.append(label)
                if not each_token_labeled and len(labels) > 1:
                    label_idx += 1
            if each_token_labeled:
                label_idx += 1

    # Create new label vocab that only contains the labels that have been
    # selected. dict.fromkeys keeps first-occurrence order, so the mapping is
    # reproducible across runs. Iterating a set of strings is not, because of
    # hash randomization.
    label_vocab = {
        label: idx for idx, label in enumerate(dict.fromkeys(all_labels))
    }
    corpus.fields[corpus.labels_column].vocab = label_vocab

    # Explicit dtype so an empty selection still yields an integer tensor.
    return torch.tensor(
        [label_vocab[label] for label in all_labels], dtype=torch.long
    )