# Source code for diagnnose.syntax.task

import glob
import os
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
import torch
from torch import Tensor
from transformers import PreTrainedTokenizer

from diagnnose.activations.selection_funcs import (
    final_token,
    no_special_tokens,
    only_mask_token,
    return_all,
)
from diagnnose.corpus import Corpus, create_iterator
from diagnnose.extract import simple_extract
from diagnnose.models import LanguageModel
from diagnnose.typedefs.activations import SelectionFunc
from diagnnose.typedefs.syntax import AccuracyDict, ScoresDict, SyntaxEvalCorpora


class SyntaxEvalTask:
    """Base class for syntactic evaluation tasks, from which specific tasks
    can inherit.

    Parameters
    ----------
    model : LanguageModel
        Language model for which the accuracy is calculated.
    tokenizer : PreTrainedTokenizer
        The model tokenizer that converts tokens into indices.
    ignore_unk : bool, optional
        Ignore cases for which at least one of the cases of the verb is not
        part of the model's tokenizer. Defaults to False.
    use_full_model_probs : bool, optional
        Toggle to calculate the full model probs for the NPI sentences. If
        set to False only the NPI logits will be compared, instead of their
        Softmax probabilities. Defaults to True.
    **config : Any
        Configuration for task initialization; forwarded as keyword
        arguments to :meth:`initialize`. The optional ``compare_full_sen``
        key (default False) is also read here.
    """

    def __init__(
        self,
        model: LanguageModel,
        tokenizer: PreTrainedTokenizer,
        ignore_unk: bool = False,
        use_full_model_probs: bool = True,
        **config: Any,
    ):
        model.eval()

        self.model = model
        self.tokenizer = tokenizer
        self.ignore_unk = ignore_unk
        self.use_full_model_probs = use_full_model_probs
        self.compare_full_sen = config.get("compare_full_sen", False)

        # A single subtask passed as a command-line argument arrives as a plain
        # string; normalise it to a one-element list.
        if isinstance(config.get("subtasks"), str):
            config["subtasks"] = [config["subtasks"]]

        self.corpora: SyntaxEvalCorpora = self.initialize(**config)

    def initialize(
        self, path: str, header: Optional[List[str]] = None
    ) -> SyntaxEvalCorpora:
        """Create one corpus per task file found at ``path``.

        Parameters
        ----------
        path : str
            Either a single task file, or a directory whose regular files are
            each loaded as a separate task.
        header : List[str], optional
            Column names of the task files. Defaults to
            ``["sen", "token", "counter_token"]``. Must contain ``"sen"``,
            ``"token"`` and one of ``"counter_sen"`` / ``"counter_token"``.

        Returns
        -------
        corpora : SyntaxEvalCorpora
            Maps each task name (file name up to its first dot) to its Corpus.

        Raises
        ------
        FileNotFoundError
            If ``path`` is neither a directory nor a file.
        """
        if header is None:
            header = ["sen", "token", "counter_token"]

        assert "sen" in header
        assert "token" in header
        assert "counter_sen" in header or "counter_token" in header

        if os.path.isdir(path):
            # Sorted for a deterministic task order; subdirectories are skipped
            # because they cannot be loaded as a corpus.
            files = sorted(
                file
                for file in glob.glob(os.path.join(path, "*"))
                if os.path.isfile(file)
            )
        elif os.path.isfile(path):
            files = [path]
        else:
            raise FileNotFoundError("Path to task is not found")

        corpora = {}
        for file in files:
            # basename (rather than splitting on "/") also handles OS-specific
            # separators; e.g. "data/npi.tsv" -> "npi".
            task_name = os.path.basename(file).split(".")[0]
            corpora[task_name] = Corpus.create(
                file, header=header, tokenizer=self.tokenizer
            )

        return corpora

    def run(self) -> Tuple[AccuracyDict, ScoresDict]:
        """Performs the syntactic evaluation task that is initialised.

        Returns
        -------
        accuracies : AccuracyDict
            Maps each subtask (and, for nested corpora, each condition) to the
            fraction of items for which ``scores > counter_scores``.
        scores : ScoresDict
            Same nesting as ``accuracies``, mapping to the per-item scores
            DataFrame.
        """
        accuracies: AccuracyDict = {}
        scores: ScoresDict = {}

        for subtask, subtask_corpora in self.corpora.items():
            if isinstance(subtask_corpora, Corpus):
                scores_df = self._run_corpus(subtask_corpora)
                scores[subtask] = scores_df
                accuracies[subtask] = self._accuracy(scores_df)
            else:
                for condition, corpus in subtask_corpora.items():
                    scores_df = self._run_corpus(corpus)
                    scores.setdefault(subtask, {})[condition] = scores_df
                    accuracies.setdefault(subtask, {})[condition] = self._accuracy(
                        scores_df
                    )

        return accuracies, scores

    @staticmethod
    def _accuracy(scores_df: pd.DataFrame) -> float:
        """Fraction of rows whose ``scores`` exceed ``counter_scores``.

        Yields NaN for an empty DataFrame (mean of an empty Series).
        """
        return (scores_df.scores > scores_df.counter_scores).mean()

    def _selection_func(self, sen_column: str) -> SelectionFunc:
        """Return the selection function for the positions of ``sen_column``
        whose activations are extracted.

        Full-sentence comparison selects all non-special tokens; otherwise
        causal models use the final token and masked models the mask token.
        """
        if self.compare_full_sen:
            return no_special_tokens(self.tokenizer, sen_column=sen_column)
        if self.model.is_causal:
            return final_token(sen_column)
        return only_mask_token(self.tokenizer.mask_token, sen_column)

    def _run_corpus(self, corpus: Corpus) -> pd.DataFrame:
        """Extract activations for ``corpus`` and return its scores DataFrame.

        Raises
        ------
        ValueError
            If ``compare_full_sen`` is set but the corpus has no
            ``counter_sen`` column to compare against.
        """
        selection_func = self._selection_func("sen")

        if self.ignore_unk:
            sen_ids = self._create_non_unk_sen_ids(corpus)
            corpus = corpus.slice(sen_ids)

        # Nothing to score: return an empty frame with the expected columns
        # instead of running the extraction with a zero batch size.
        if len(corpus) == 0:
            return pd.DataFrame(columns=["scores", "counter_scores"])

        has_counter_sen = "counter_sen" in corpus.fields
        if self.compare_full_sen and not has_counter_sen:
            raise ValueError(
                "compare_full_sen requires a 'counter_sen' column in the corpus"
            )

        activations = self._calc_final_hidden(corpus, selection_func)

        counter_activations = None
        counter_selection_func = None
        if has_counter_sen:
            counter_selection_func = self._selection_func("counter_sen")

            # Temporarily point the corpus at the counter sentences. Restore the
            # original column afterwards: without ignore_unk this corpus is the
            # object stored in self.corpora, and leaving it switched would make
            # a later run() extract counter sentences as the main sentences.
            orig_sen_column = corpus.sen_column
            corpus.sen_column = "counter_sen"
            try:
                counter_activations = self._calc_final_hidden(
                    corpus, counter_selection_func
                )
            finally:
                corpus.sen_column = orig_sen_column

        if self.compare_full_sen:
            return self._calc_full_sen_scores(
                corpus,
                activations,
                counter_activations,
                selection_func,
                counter_selection_func,
            )

        return self._calc_scores(
            corpus,
            activations,
            counter_activations=counter_activations,
        )

    def _create_non_unk_sen_ids(self, corpus: Corpus) -> List[int]:
        """Return the indices of the items that contain no unknown words.

        An item is skipped when any word of ``sen`` or the eval ``token`` is
        missing from the tokenizer vocabulary, or, when present, the
        ``counter_token`` or any word of ``counter_sen`` is missing.
        """
        vocab = self.tokenizer.vocab

        sen_ids = []
        for idx, ex in enumerate(corpus):
            # An unk token may neither appear in the prefix sen, nor be the eval token itself.
            if any(w not in vocab for w in ex.sen) or ex.token not in vocab:
                continue
            if hasattr(ex, "counter_token") and ex.counter_token not in vocab:
                continue
            if hasattr(ex, "counter_sen") and any(
                w not in vocab for w in ex.counter_sen
            ):
                continue

            sen_ids.append(idx)

        return sen_ids

    def _calc_final_hidden(
        self,
        corpus: Corpus,
        selection_func: SelectionFunc,
    ) -> Tensor:
        """Extract top-layer ``"hx"`` activations at the positions picked by
        ``selection_func``, processing the whole corpus as a single batch.

        In ``compare_full_sen`` mode the full reader contents
        (``activation_reader[:]``) are returned, presumably one entry per
        sentence (TODO confirm); otherwise the activation tensor for the top
        layer is returned.
        """
        activation_name = (self.model.top_layer, "hx")

        activation_reader, _ = simple_extract(
            self.model,
            corpus,
            [activation_name],
            batch_size=len(corpus),
            selection_func=selection_func,
        )

        if self.compare_full_sen:
            return activation_reader[:]

        return activation_reader.activation_dict[activation_name]

    def _calc_full_sen_scores(
        self,
        corpus: Corpus,
        activations: Tensor,
        counter_activations: Tensor,
        selection_func: SelectionFunc,
        counter_selection_func: SelectionFunc,
    ) -> pd.DataFrame:
        """Score each sentence pair by the summed log-probabilities of its
        selected tokens.

        Returns a DataFrame with ``sen``, ``counter_sen``, ``scores`` and
        ``counter_scores`` columns, one row per corpus item.
        """
        scores_df = pd.DataFrame(
            {
                "sen": [ex.sen for ex in corpus],
                "counter_sen": [ex.counter_sen for ex in corpus],
            }
        )

        scores = torch.zeros(len(corpus))
        counter_scores = torch.zeros(len(corpus))

        # The iterator tokenizes the sentences for us, so we can index the
        # log-probabilities with the sentence token ids directly.
        corpus_iterator = create_iterator(
            corpus, batch_size=1, device=self.model.device
        )

        for idx, (activation, counter_activation, batch_item, corpus_item) in enumerate(
            zip(activations, counter_activations, corpus_iterator, corpus.examples)
        ):
            scores[idx] = self._selected_log_prob(
                activation,
                batch_item.sen[0].squeeze(),
                selection_func,
                corpus_item,
            )
            counter_scores[idx] = self._selected_log_prob(
                counter_activation,
                batch_item.counter_sen[0].squeeze(),
                counter_selection_func,
                corpus_item,
            )

        scores_df["scores"] = scores
        scores_df["counter_scores"] = counter_scores

        return scores_df

    def _selected_log_prob(
        self,
        activation: Tensor,
        sen_ids: Tensor,
        selection_func: SelectionFunc,
        corpus_item: Any,
    ) -> Tensor:
        """Sum the log-probabilities of the tokens picked by ``selection_func``.

        Row ``i`` of ``activation`` is decoded and scored against the ``i``-th
        selected token id of ``sen_ids``.
        """
        token_ids = [
            token_idx
            for w_idx, token_idx in enumerate(sen_ids)
            if selection_func(w_idx, corpus_item)
        ]

        # NOTE(review): this scores the token at the same position as the hidden
        # state. For a causal model the state at position t predicts token t+1;
        # confirm the alignment produced by the selection function is intended.
        log_probs = self._decode(activation)  # already log-softmax normalised
        return log_probs[range(len(token_ids)), token_ids].sum()

    def _calc_scores(
        self,
        corpus: Corpus,
        activations: Tensor,
        counter_activations: Optional[Tensor] = None,
    ) -> pd.DataFrame:
        """Score each item's eval token against its counter token (single
        context) or against the same token in the counter sentence (two
        contexts, when ``counter_activations`` is given).
        """
        token_ids = torch.tensor(
            [self.tokenizer.convert_tokens_to_ids(ex.token) for ex in corpus]
        )

        scores_df = pd.DataFrame(
            {
                "sen": [ex.sen for ex in corpus],
                "token": [ex.token for ex in corpus],
            }
        )

        if counter_activations is None:
            scores_df["counter_token"] = [ex.counter_token for ex in corpus]
            counter_token_ids = torch.tensor(
                [
                    self.tokenizer.convert_tokens_to_ids(ex.counter_token)
                    for ex in corpus
                ]
            )
            scores, counter_scores = self._single_context_accuracy(
                activations, token_ids, counter_token_ids
            )
        else:
            scores_df["counter_sen"] = [ex.counter_sen for ex in corpus]
            scores, counter_scores = self._dual_context_accuracy(
                activations, counter_activations, token_ids
            )

        scores_df["scores"] = scores.detach()
        scores_df["counter_scores"] = counter_scores.detach()

        return scores_df

    def _single_context_accuracy(
        self, activations: Tensor, token_ids: Tensor, counter_token_ids: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Return the scores used to compare P(w1|h) > P(w2|h).

        Both tokens are decoded from the same activations, so raw logits are
        compared: the softmax normaliser is shared and cancels out.
        """
        logits = self._decode(activations, token_ids)
        counter_logits = self._decode(activations, counter_token_ids)

        return logits, counter_logits

    def _dual_context_accuracy(
        self,
        activations: Tensor,
        counter_activations: Tensor,
        token_ids: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Return the scores used to compare P(w|h1) > P(w|h2).

        With ``use_full_model_probs`` the scores are log-softmax probabilities
        over the full vocabulary; otherwise they are raw logits, which are not
        normalised across the two contexts.
        """
        if self.use_full_model_probs:
            log_probs = self._decode(activations)
            counter_log_probs = self._decode(counter_activations)

            batch_size = log_probs.shape[0]
            probs = log_probs[range(batch_size), token_ids]
            counter_probs = counter_log_probs[range(batch_size), token_ids]
        else:
            probs = self._decode(activations, token_ids)
            counter_probs = self._decode(counter_activations, token_ids)

        return probs, counter_probs

    def _decode(
        self, activations: Tensor, token_ids: Optional[Tensor] = None
    ) -> Tensor:
        """Project activations onto the model vocabulary.

        Without ``token_ids`` returns log-softmax scores over the full
        vocabulary, for both Transformer and LSTM models. With ``token_ids``
        returns the raw logit of ``token_ids[i]`` for row ``i``, as a 1-D
        tensor of length ``batch_size``. Gradients are never tracked.

        Raises
        ------
        AttributeError
            If the model has neither a ``decoder`` nor a ``decoder_w``.
        """
        with torch.no_grad():
            if hasattr(self.model, "decoder"):
                # Transformers
                logits = getattr(self.model, "decoder")(activations)

                if token_ids is None:
                    return logits.log_softmax(-1)

                batch_size = logits.size(0)
                return logits[range(batch_size), token_ids]

            if hasattr(self.model, "decoder_w"):
                # LSTMs
                decoder_w = self.model.decoder_w
                decoder_b = self.model.decoder_b

                if token_ids is None:
                    logits = activations @ decoder_w.t() + decoder_b
                    return torch.nn.functional.log_softmax(logits, dim=-1)

                # (batch, 1, hidden) @ (batch, hidden, 1) -> (batch, 1, 1).
                # Squeeze only the two singleton dims so that a batch of size 1
                # stays 1-D instead of collapsing to a scalar.
                token_w = decoder_w[token_ids].unsqueeze(1)
                logits = torch.bmm(token_w, activations.unsqueeze(2))
                return logits.squeeze(2).squeeze(1) + decoder_b[token_ids]

        raise AttributeError("Model decoder not found")