Source code for diagnnose.models.wrappers.awd_lstm

from typing import Any, Dict

from .forward_lstm import ForwardLSTM


class AWDLSTM(ForwardLSTM):
    """``ForwardLSTM`` variant with AWD-LSTM specific parameter naming.

    Differs from the parent class in two visible ways: ``rnn_name``
    defaults to ``"rnns"`` and :meth:`param_names` is overridden to
    produce names of the form ``{rnn_name}.{layer}.module.<param>``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to ``ForwardLSTM``, defaulting ``rnn_name``.

        A caller-supplied ``rnn_name`` keyword is left untouched; only a
        missing one is filled in with ``"rnns"``.
        """
        if "rnn_name" not in kwargs:
            kwargs["rnn_name"] = "rnns"

        super().__init__(*args, **kwargs)

    @staticmethod
    def param_names(
        layer: int, rnn_name: str, no_suffix: bool = False, **kwargs: Any
    ) -> Dict[str, str]:
        """Return the parameter name for each LSTM weight and bias.

        Parameters
        ----------
        layer : int
            Layer index that is interpolated into every returned name.
        rnn_name : str
            Prefix that is interpolated into every returned name.
        no_suffix : bool, optional
            When true, every name is returned as the empty string.
            Defaults to False.
        **kwargs
            Accepted but not used by this implementation.

        Returns
        -------
        Dict[str, str]
            Mapping with the keys ``weight_hh``, ``weight_ih``,
            ``bias_hh`` and ``bias_ih``, in that order.
        """
        if no_suffix:
            # A single layer AWD-LSTM exposes no separate weight names, so
            # each entry is left empty.
            return dict.fromkeys(
                ("weight_hh", "weight_ih", "bias_hh", "bias_ih"), ""
            )

        # Only the hidden-to-hidden weight is addressed through its ``_raw``
        # name; the remaining three use the plain ``_l0`` names.
        return {
            "weight_hh": f"{rnn_name}.{layer}.module.weight_hh_l0_raw",
            "weight_ih": f"{rnn_name}.{layer}.module.weight_ih_l0",
            "bias_hh": f"{rnn_name}.{layer}.module.bias_hh_l0",
            "bias_ih": f"{rnn_name}.{layer}.module.bias_ih_l0",
        }