Source code for diagnnose.syntax.tasks.warstadt

from typing import Dict, List, Optional

from torchtext.data import Example

from diagnnose.corpus import Corpus
from diagnnose.typedefs.syntax import SyntaxEvalCorpora

from ..task import SyntaxEvalTask
from .warstadt_preproc import ENVS, create_downstream_corpus, preproc_warstadt


class WarstadtTask(SyntaxEvalTask):
    """Syntactic evaluation task backed by the Warstadt preprocessing helpers."""

    def initialize(
        self, path: str, subtasks: Optional[List[str]] = None
    ) -> SyntaxEvalCorpora:
        """Builds one corpus per requested environment of the Warstadt data.

        The raw data is loaded once with ``preproc_warstadt`` and then
        narrowed down to a single environment at a time with
        ``create_downstream_corpus``.

        NOTE(review): an earlier version of this docstring cited
        Marvin & Linzen (2018) and its repository, which appears to have
        been copied from a different task. Presumably the intended
        reference is the Warstadt et al. dataset -- confirm and add the
        correct link.

        Parameters
        ----------
        path : str
            Location of the Warstadt data, passed unchanged to
            ``preproc_warstadt``.
        subtasks : List[str], optional
            The environments (subtasks) to build a corpus for. If not
            provided, or empty, this defaults to every entry of ``ENVS``.

        Returns
        -------
        corpora : Dict[str, Corpus]
            Dictionary mapping each requested environment to its Corpus.
        """
        # Bind the resolved list to a new local instead of re-annotating the
        # ``subtasks`` parameter, which type checkers flag as a redefinition.
        envs: List[str] = subtasks or ENVS

        corpora: SyntaxEvalCorpora = {}

        # Preprocess the raw data once; each environment is filtered from it.
        orig_corpus = preproc_warstadt(path)

        for env in envs:
            raw_corpus = create_downstream_corpus(orig_corpus, envs=[env])

            # The first line holds the tab-separated column names; the
            # remaining lines are the tab-separated items.
            header = raw_corpus[0].split("\t")
            tokenize_columns = ["sen", "counter_sen"]
            fields = Corpus.create_fields(
                header, tokenize_columns=tokenize_columns, tokenizer=self.tokenizer
            )

            examples = [
                Example.fromlist(line.split("\t"), fields) for line in raw_corpus[1:]
            ]
            corpora[env] = Corpus(examples, fields)

        return corpora