Source code for diagnnose.attribute.explainer

from typing import List, Tuple, Union

from torch import Tensor
from transformers import BatchEncoding, PreTrainedTokenizer

from .decomposer import Decomposer


class Explainer:
    """ Generates an explanation for a specific input. """

    def __init__(
        self,
        decomposer: Decomposer,
        tokenizer: PreTrainedTokenizer,
    ):
        self.decomposer = decomposer
        self.tokenizer = tokenizer
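
    # Note on the two collaborators: as used in `explain` and
    # `print_attributions` below, `decomposer.decompose` is expected to return
    # the full model output plus one contribution tensor per feature (the
    # model bias first, then one per input token), each shaped
    # batch_size x max_sen_len x output_dim; the tokenizer maps raw sentences
    # and candidate output tokens to the ids those tensors are indexed with.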

    def explain(self, input_tokens: Union[str, List[str]], output_tokens: List[str]):
        batch_encoding = self._tokenize(input_tokens)

        out, contributions = self.decomposer.decompose(batch_encoding)

        output_ids, mask_ids = self._create_output_ids(batch_encoding, output_tokens)

        full_probs = self._fetch_token_probs(out, output_ids, mask_ids)
        contribution_probs = [
            self._fetch_token_probs(contribution, output_ids, mask_ids)
            for contribution in contributions
        ]

        return full_probs, contribution_probs
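
    # Illustrative call (a sketch, not part of this module): with a masked-LM
    # tokenizer whose mask token is <mask>,
    #
    #   full_probs, contribution_probs = explainer.explain(
    #       "The <mask> sat on the mat.", ["cat", "dog"]
    #   )
    #
    # returns the model score for each candidate token at the mask position,
    # plus the per-feature contributions to those scores.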

    def _tokenize(self, input_tokens: Union[str, List[str]]) -> BatchEncoding:
        input_tokens = [input_tokens] if isinstance(input_tokens, str) else input_tokens

        batch_encoding = self.tokenizer(
            input_tokens,
            padding=True,
            return_attention_mask=False,
            return_length=True,
            return_token_type_ids=False,
        )

        return batch_encoding

    def _create_output_ids(
        self, batch_encoding: BatchEncoding, output_tokens: List[str]
    ) -> Tuple[List[int], List[int]]:
        mask_idx = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        unk_idx = self.tokenizer.convert_tokens_to_ids(self.tokenizer.unk_token)
        mask_idx = None if mask_idx == unk_idx else mask_idx

        mask_ids = [
            sen.index(mask_idx) if mask_idx in sen else final_idx - 1
            for sen, final_idx in zip(
                batch_encoding["input_ids"], batch_encoding["length"]
            )
        ]

        output_ids = []
        mask_token = self.tokenizer.mask_token or "<mask>"

        for token in output_tokens:
            sub_token_id = self.tokenizer.convert_tokens_to_ids(token)
            if sub_token_id == unk_idx:
                # Simply encoding "token" often yields a different index than
                # when it is embedded within a sentence, hence the ugly hack here.
                sub_token_ids = self.tokenizer.encode(
                    f"{mask_token} {token}", add_special_tokens=False
                )[1:]
                sub_tokens = self.tokenizer.convert_ids_to_tokens(sub_token_ids)
                assert (
                    len(sub_token_ids) == 1
                ), f"Multi-subword tokens not supported ({token} -> {str(sub_tokens)})"
                sub_token_id = sub_token_ids[0]

            output_ids.append(sub_token_id)

        return output_ids, mask_ids

    @staticmethod
    def _fetch_token_probs(
        probs: Tensor, output_ids: List[int], mask_ids: List[int]
    ) -> Tensor:
        """Fetches the probability of each output class at the position of the
        corresponding mask_idx.

        Parameters
        ----------
        probs : Tensor
            Tensor with output probabilities of shape:
            batch_size x max_sen_len x output_dim.
        output_ids : List[int]
            List of indices of the output classes that are decomposed.
        mask_ids : List[int]
            List of indices that signify the position of each sentence in the
            input batch where the decomposition will take place.

        Returns
        -------
        token_probs : Tensor
            Tensor containing the corresponding probabilities.
        """
        mask_probs = probs[range(probs.size(0)), mask_ids]
        token_probs = mask_probs[:, output_ids]

        return token_probs
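
    # Indexing walk-through: for `probs` of shape
    # (batch_size, max_sen_len, output_dim), the pairwise index
    # probs[range(batch_size), mask_ids] selects one position per sentence,
    # giving a (batch_size, output_dim) tensor; [:, output_ids] then keeps
    # only the requested output classes, so `token_probs` has shape
    # (batch_size, len(output_ids)).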

    def print_attributions(
        self,
        full_probs: Tensor,
        contribution_probs: List[Tensor],
        input_tokens: Union[str, List[str]],
        output_tokens: List[str],
    ):
        batch_encoding = self._tokenize(input_tokens)

        for sen_idx, token_ids in enumerate(batch_encoding["input_ids"]):
            print((" " * 15) + "".join(f"{w:<15}" for w in output_tokens))
            print(
                f"{'Full logits':<15}"
                + "".join(f"{p:<15.3f}" for p in full_probs[sen_idx])
            )
            print("-" * 15 * (len(output_tokens) + 1))

            sen_features = [self.tokenizer.decode([w]) for w in token_ids]
            sen_len = batch_encoding["length"][sen_idx]
            features = ["model_bias", *sen_features[:sen_len]]

            for i, feature in enumerate(features):
                print(
                    f"{feature:<15}"
                    + "".join(f"{p:<15.3f}" for p in contribution_probs[i][sen_idx])
                )
            print("\n")
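

# Usage sketch (illustrative; the checkpoint name and decomposer construction
# below are assumptions, not part of this module: substitute the concrete
# Decomposer subclass your diagnnose setup provides):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("roberta-base")
#   decomposer = ...  # any concrete Decomposer implementation
#   explainer = Explainer(decomposer, tokenizer)
#
#   sentences = ["The <mask> sat on the mat."]
#   candidates = ["cat", "dog"]
#   full_probs, contribution_probs = explainer.explain(sentences, candidates)
#   explainer.print_attributions(full_probs, contribution_probs, sentences, candidates)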