Source code for diagnnose.models.wrappers.forward_lstm

import os
from typing import Dict

import torch
from torch import Tensor

from diagnnose.models.recurrent_lm import RecurrentLM


class ForwardLSTM(RecurrentLM):
    """Defines a default uni-directional n-layer LSTM.

    Allows for extraction of intermediate states and gate activations.

    Parameters
    ----------
    state_dict : str
        Path to torch pickle containing the model parameter state dict.
    device : str, optional
        Torch device on which forward passes will be run.
        Defaults to cpu.
    rnn_name : str, optional
        Name of the rnn in the model state_dict. Defaults to `rnn`.
    encoder_name : str, optional
        Name of the embedding encoder in the model state_dict.
        Defaults to `encoder`.
    decoder_name : str, optional
        Name of the linear decoder in the model state_dict.
        Defaults to `decoder`.
    """

    ih_concat_order = ["h", "i"]
    split_order = ["i", "f", "g", "o"]

    def __init__(
        self,
        state_dict: str,
        device: str = "cpu",
        rnn_name: str = "rnn",
        encoder_name: str = "encoder",
        decoder_name: str = "decoder",
    ) -> None:
        super().__init__(device)
        print("Loading pretrained model...")

        with open(os.path.expanduser(state_dict), "rb") as mf:
            params: Dict[str, Tensor] = torch.load(mf, map_location=device)

        params = params["state_dict"] if "state_dict" in params else params

        self._set_lstm_weights(params, rnn_name)

        # Encoder and decoder weights
        self.word_embeddings: Tensor = params[f"{encoder_name}.weight"]
        self.decoder_w: Tensor = params[f"{decoder_name}.weight"]
        self.decoder_b: Tensor = params[f"{decoder_name}.bias"]

        self.sizes[self.top_layer, "out"] = self.decoder_b.shape[0]

        print("Model initialisation finished.")

    def _set_lstm_weights(self, params: Dict[str, Tensor], rnn_name: str) -> None:
        # LSTM weights
        layer = 0

        # 1-layer RNNs do not have a layer suffix in the state_dict
        no_suffix = self.param_names(0, rnn_name, no_suffix=True)["weight_hh"] in params

        assert (
            self.param_names(0, rnn_name, no_suffix=no_suffix)["weight_hh"] in params
        ), "rnn weight name not found, check if setup is correct"

        while (
            self.param_names(layer, rnn_name, no_suffix=no_suffix)["weight_hh"]
            in params
        ):
            param_names = self.param_names(layer, rnn_name, no_suffix=no_suffix)

            w_h = params[param_names["weight_hh"]]
            w_i = params[param_names["weight_ih"]]

            # Shape: (emb_size+nhid_h, 4*nhid_c)
            self.weight[layer] = torch.cat((w_h, w_i), dim=1).t()

            if param_names["bias_hh"] in params:
                # Shape: (4*nhid_c,)
                self.bias[layer] = (
                    params[param_names["bias_hh"]] + params[param_names["bias_ih"]]
                )

            self.sizes.update(
                {
                    (layer, "emb"): w_i.size(1),
                    (layer, "hx"): w_h.size(1),
                    (layer, "cx"): w_h.size(1),
                }
            )
            layer += 1

            if no_suffix:
                break
    @staticmethod
    def param_names(
        layer: int, rnn_name: str, no_suffix: bool = False
    ) -> Dict[str, str]:
        """Creates a dictionary of parameter names in a state_dict.

        Parameters
        ----------
        layer : int
            Current layer index.
        rnn_name : str
            Name of the rnn in the model state_dict.
        no_suffix : bool, optional
            Toggle to omit the `_l{layer}` suffix from a parameter name.
            1-layer RNNs do not have this suffix. Defaults to False.

        Returns
        -------
        param_names : Dict[str, str]
            Dictionary mapping a general parameter name to the model
            specific parameter name.
        """
        if no_suffix:
            return {
                "weight_hh": f"{rnn_name}.weight_hh",
                "weight_ih": f"{rnn_name}.weight_ih",
                "bias_hh": f"{rnn_name}.bias_hh",
                "bias_ih": f"{rnn_name}.bias_ih",
            }

        return {
            "weight_hh": f"{rnn_name}.weight_hh_l{layer}",
            "weight_ih": f"{rnn_name}.weight_ih_l{layer}",
            "bias_hh": f"{rnn_name}.bias_hh_l{layer}",
            "bias_ih": f"{rnn_name}.bias_ih_l{layer}",
        }
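
The following is a minimal usage sketch, not part of the module source. The state dict path `model_state_dict.pt` is a hypothetical placeholder; the constructor arguments and the `param_names` call follow the signatures defined above, and the expected keys assume a standard multi-layer `nn.LSTM` state dict.

    from diagnnose.models.wrappers.forward_lstm import ForwardLSTM

    # Hypothetical path to a torch pickle containing the model parameters.
    model = ForwardLSTM(
        "model_state_dict.pt",
        device="cpu",
        rnn_name="rnn",
        encoder_name="encoder",
        decoder_name="decoder",
    )

    # param_names maps generic parameter names to the state_dict keys that
    # _set_lstm_weights looks up, e.g. for layer 0 of a multi-layer LSTM:
    ForwardLSTM.param_names(0, "rnn")
    # {'weight_hh': 'rnn.weight_hh_l0', 'weight_ih': 'rnn.weight_ih_l0',
    #  'bias_hh': 'rnn.bias_hh_l0', 'bias_ih': 'rnn.bias_ih_l0'}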