# Source code for diagnnose.syntax.evaluator

from typing import Any, Dict, List, Optional, Tuple, Type

from transformers import PreTrainedTokenizer

from diagnnose.models import LanguageModel
from diagnnose.typedefs.syntax import AccuracyDict, ScoresDict

from .task import SyntaxEvalTask
from .tasks import (
    BlimpTask,
    LakretzTask,
    LinzenTask,
    MarvinTask,
    WarstadtTask,
    WinobiasTask,
)

# Registry of task names that have a dedicated SyntaxEvalTask subclass.
# Looked up by name when the evaluator instantiates its tasks.
task_constructors: Dict[str, Type[SyntaxEvalTask]] = dict(
    blimp=BlimpTask,
    lakretz=LakretzTask,
    linzen=LinzenTask,
    marvin=MarvinTask,
    warstadt=WarstadtTask,
    winobias=WinobiasTask,
)


class SyntacticEvaluator:
    """Suite that runs multiple syntactic evaluation tasks on a LM.

    Tasks can be run on already extracted activations, or on a new LM
    for which new activations will be extracted. Initialisation is
    performed separately from the tasks themselves, in order to allow
    multiple LMs to be run on the same set of tasks.

    Parameters
    ----------
    model : LanguageModel
        Language model that is handed to every task constructor.
    tokenizer : PreTrainedTokenizer
        Tokenizer that converts tokens to their index within the LM.
    config : Dict[str, Any]
        Dictionary mapping a task name (`blimp`, `lakretz`, `linzen`,
        `marvin`, `warstadt`, or `winobias`) to its configuration
        (`path`, `tasks`, and `task_activations`). `path` points to the
        corpus folder of the task, `tasks` is an optional list of
        subtasks, and `task_activations` an optional path to the folder
        containing the model activations. The configuration of a task
        is unpacked as keyword arguments into its constructor. A task
        name without a dedicated constructor falls back to the generic
        ``SyntaxEvalTask``.
    ignore_unk : bool, optional
        Forwarded unchanged to every task constructor. Defaults to
        False. NOTE(review): presumably controls how items containing
        unknown tokens are treated -- confirm against SyntaxEvalTask.
    use_full_model_probs : bool, optional
        Forwarded unchanged to every task constructor. Defaults to
        True. NOTE(review): exact semantics live in SyntaxEvalTask --
        confirm there.
    tasks : List[str], optional
        Names of the tasks to initialise. When omitted (or empty),
        every task present in ``config`` is initialised.

    Raises
    ------
    KeyError
        If a requested task name has no entry in ``config``.
    """

    def __init__(
        self,
        model: LanguageModel,
        tokenizer: PreTrainedTokenizer,
        config: Dict[str, Any],
        ignore_unk: bool = False,
        use_full_model_probs: bool = True,
        tasks: Optional[List[str]] = None,
    ) -> None:
        # Initialised tasks, keyed by task name, in initialisation order.
        self.tasks: Dict[str, SyntaxEvalTask] = {}

        print("Initializing syntactic evaluation tasks...")

        # An empty or missing selection means "everything in the config".
        for task_name in tasks or config.keys():
            print(task_name)

            try:
                task_config = config[task_name]
            except KeyError:
                # Same exception type as before, but say what went wrong
                # instead of surfacing a bare key.
                raise KeyError(
                    f"Task '{task_name}' was requested but has no entry "
                    f"in the syntactic evaluation config"
                ) from None

            # Unknown task names are handled by the generic base task.
            constructor = task_constructors.get(task_name, SyntaxEvalTask)

            self.tasks[task_name] = constructor(
                model,
                tokenizer,
                ignore_unk=ignore_unk,
                use_full_model_probs=use_full_model_probs,
                **task_config,
            )

        print("Syntactic evaluation task initialization finished")

    def run(self) -> Tuple[Dict[str, AccuracyDict], Dict[str, ScoresDict]]:
        """Run every initialised task and collect the results.

        A banner with the upper-cased task name is printed before each
        task is run.

        Returns
        -------
        accuracies : Dict[str, AccuracyDict]
            Maps each task name to the first item returned by that
            task's ``run`` method.
        scores : Dict[str, ScoresDict]
            Maps each task name to the second item returned by that
            task's ``run`` method.
        """
        accuracies: Dict[str, AccuracyDict] = {}
        scores: Dict[str, ScoresDict] = {}

        for task_name, task in self.tasks.items():
            print(f"\n--=={task_name.upper()}==--")

            accuracies[task_name], scores[task_name] = task.run()

        return accuracies, scores