# Source code for diagnnose.syntax.tasks.winobias

import os
from typing import Dict, List, Optional

from torchtext.data import RawField

from diagnnose.corpus import Corpus
from diagnnose.typedefs.syntax import SyntaxEvalCorpora

from ..task import SyntaxEvalTask


class WinobiasTask(SyntaxEvalTask):
    """Syntactic evaluation task that loads the Winobias corpora.

    Every corpus is read from a ``.tsv`` file and gets a fixed pair of
    pronouns attached to each example: ``"he"`` as ``token`` and
    ``"she"`` as ``counter_token``.
    """

    def initialize(
        self, path: str, subtasks: Optional[List[str]] = None
    ) -> SyntaxEvalCorpora:
        """Load one corpus per (subtask, condition) pair.

        Parameters
        ----------
        path : str
            Path to directory containing the Winobias datasets that can
            be found in the github repo. Each file is expected to be
            named ``{subtask}_{condition}.tsv``.
        subtasks : List[str], optional
            The downstream tasks that will be tested. If not provided
            (or empty) this will default to the full set of subtasks:
            ``["stereo", "unamb"]``.

        Returns
        -------
        corpora : SyntaxEvalCorpora
            Nested dictionary mapping a subtask to a dictionary that
            maps each condition (``"FF"``, ``"FM"``, ``"MF"``, ``"MM"``)
            to its Corpus.
        """
        subtasks = subtasks or ["stereo", "unamb"]

        corpora: SyntaxEvalCorpora = {}

        for subtask in subtasks:
            # NOTE(review): the condition codes presumably encode the
            # gender (F/M) of the two referents in a sentence -- confirm
            # against the dataset description.
            for condition in ["FF", "FM", "MF", "MM"]:
                corpus = Corpus.create(
                    os.path.join(path, f"{subtask}_{condition}.tsv"),
                    header_from_first_line=True,
                    tokenizer=self.tokenizer,
                )

                self._add_output_classes(corpus)

                corpora.setdefault(subtask, {})[condition] = corpus

        return corpora

    @staticmethod
    def _add_output_classes(corpus: Corpus) -> None:
        """Set the pronouns for each sentence.

        Registers ``token`` and ``counter_token`` as raw fields on the
        corpus and assigns ``"he"`` / ``"she"`` to every example. The
        corpus is modified in place.

        Parameters
        ----------
        corpus : Corpus
            The corpus whose fields and examples are updated.
        """
        corpus.fields["token"] = RawField()
        corpus.fields["counter_token"] = RawField()

        # Both pronoun fields are explicitly flagged as non-target fields.
        corpus.fields["token"].is_target = False
        corpus.fields["counter_token"].is_target = False

        for ex in corpus:
            ex.token = "he"
            ex.counter_token = "she"