import os
import sys
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
from torch import Tensor
from diagnnose.models.recurrent_lm import RecurrentLM
from diagnnose.tokenizer import create_char_vocab
from diagnnose.tokenizer.c2i import C2I


class GoogleLM(RecurrentLM):
    """Reimplementation of the LM of Jozefowicz et al. (2016).

    Paper: https://arxiv.org/abs/1602.02410
    Lib: https://github.com/tensorflow/models/tree/master/research/lm_1b

    This implementation allows for only a subset of the softmax to be
    loaded in, to alleviate RAM usage.

    Parameters
    ----------
    ckpt_dir : str
        Path to the folder containing the parameter checkpoint files.
    corpus_vocab_path : str or List[str], optional
        Path to the corpus for which a vocabulary will be created. This
        allows for only a subset of the model softmax to be loaded in.
    create_decoder : bool
        Toggle to load the (partial) softmax weights. Can be set to
        False when no decoding projection is needed, as is the case
        during activation extraction, for example.
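
    Examples
    --------
    A minimal usage sketch; the paths below are hypothetical::

        model = GoogleLM(
            ckpt_dir="checkpoints/lm_1b",
            corpus_vocab_path="corpora/my_corpus.txt",
        )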
"""
sizes = {
(layer, name): size
for layer in range(2)
for name, size in [("emb", 1024), ("hx", 1024), ("cx", 8192)]
}
forget_offset = 1
ih_concat_order = ["i", "h"]
split_order = ["i", "g", "f", "o"]
use_char_embs = True
use_peepholes = True
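
    # A hedged sketch of the cell computation these attributes describe:
    # a standard peephole LSTM in the lm_1b setup (illustrative pseudocode,
    # names are not taken from this module):
    #
    #   proj = W @ cat([x, h_prev]) + b          # ih_concat_order: ["i", "h"]
    #   i_t = sigmoid(proj_i + p_i * c_prev)     # split_order: i, g, f, o
    #   f_t = sigmoid(proj_f + p_f * c_prev + 1) # forget_offset = 1
    #   g_t = tanh(proj_g)
    #   c_t = f_t * c_prev + i_t * g_t
    #   o_t = sigmoid(proj_o + p_o * c_t)
    #   h_t = W_P @ (o_t * tanh(c_t))            # projects 8192 -> 1024
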
def __init__(
self,
ckpt_dir: str,
corpus_vocab_path: Optional[Union[str, List[str]]] = None,
create_decoder: bool = True,
device: str = "cpu",
) -> None:
super().__init__(device)
print("Loading pretrained model...")
full_vocab_path = os.path.join(ckpt_dir, "vocab-2016-09-10.txt")
vocab: C2I = create_char_vocab(
corpus_vocab_path or full_vocab_path, unk_token="<UNK>"
)
self.encoder = CharCNN(ckpt_dir, vocab, device)
self._set_lstm_weights(ckpt_dir, device)
if create_decoder:
self.decoder = SoftMax(
vocab, full_vocab_path, ckpt_dir, self.sizes[1, "hx"], device
)
self.decoder_w = self.decoder.decoder_w
self.decoder_b = self.decoder.decoder_b
print("Model initialisation finished.")

    def decode(self, hidden_state: Tensor) -> Tensor:
        """Projects a hidden state onto the (partial) vocabulary logits."""
        return self.decoder_w @ hidden_state + self.decoder_b
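
    # A usage sketch with hypothetical tensors: given a projected hidden
    # state `h` of shape (1024,), `decode` yields unnormalised logits over
    # the (sub)vocabulary:
    #
    #   logits = model.decode(h)
    #   probs = torch.softmax(logits, dim=-1)
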
def _set_lstm_weights(self, ckpt_dir: str, device: str) -> None:
from tensorflow.compat.v1.train import NewCheckpointReader
print("Loading LSTM...")
lstm_reader = NewCheckpointReader(os.path.join(ckpt_dir, "ckpt-lstm"))
for l in range(self.num_layers):
# Model weights are divided into 8 chunks
# Shape: (2048, 32768)
self.weight[l] = torch.cat(
[
torch.from_numpy(lstm_reader.get_tensor(f"lstm/lstm_{l}/W_{i}"))
for i in range(8)
],
dim=0,
)
# Shape: (32768,)
self.bias[l] = torch.from_numpy(lstm_reader.get_tensor(f"lstm/lstm_{l}/B"))
# Shape: (8192, 1024)
self.weight_P[l] = torch.cat(
[
torch.from_numpy(lstm_reader.get_tensor(f"lstm/lstm_{l}/W_P_{i}"))
for i in range(8)
],
dim=0,
)
            for p in ["f", "i", "o"]:
                # Diagonal peephole weights. Shape: (8192,)
                self.peepholes[l, p] = torch.from_numpy(
                    lstm_reader.get_tensor(f"lstm/lstm_{l}/W_{p.upper()}_diag")
                )
            # Move all weight tensors to the target device
self.weight[l] = self.weight[l].to(device)
self.weight_P[l] = self.weight_P[l].to(device)
self.bias[l] = self.bias[l].to(device)
for p in ["f", "i", "o"]:
self.peepholes[l, p] = self.peepholes[l, p].to(device)


class CharCNN(nn.Module):
    def __init__(self, ckpt_dir: str, vocab: C2I, device: str) -> None:
print("Loading CharCNN...")
super().__init__()
self.cnn_sess, self.cnn_t = self._load_char_cnn(ckpt_dir)
self.cnn_embs: Dict[str, Tensor] = {}
self.vocab = vocab
self.device = device
@staticmethod
def _load_char_cnn(ckpt_dir: str) -> Any:
import tensorflow as tf
from google.protobuf import text_format
pbtxt_file = os.path.join(ckpt_dir, "graph-2016-09-10.pbtxt")
ckpt_file = os.path.join(ckpt_dir, "ckpt-char-embedding")
with tf.compat.v1.Graph().as_default():
sys.stderr.write("Recovering graph.\n")
with tf.compat.v1.gfile.FastGFile(pbtxt_file, "r") as f:
s = f.read()
gd = tf.compat.v1.GraphDef()
text_format.Merge(s, gd)
t = dict()
[t["char_inputs_in"], t["all_embs"]] = tf.import_graph_def(
gd, {}, ["char_inputs_in:0", "all_embs_out:0"], name=""
)
sess = tf.compat.v1.Session(
config=tf.compat.v1.ConfigProto(allow_soft_placement=True)
)
sess.run(f"save/Assign", {"save/Const:0": ckpt_file})
# The following was recovered from the graph structure, the first 62 assign modules
# relate to the parameters of the char CNN.
for i in range(1, 62):
sess.run(f"save/Assign_{i}", {"save/Const:0": ckpt_file})
return sess, t
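
    # Note: the restored TF session maps the character ids of a single token
    # to its 1024-dimensional embedding; `forward` below feeds tokens through
    # it one at a time and memoises the results in `self.cnn_embs`.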
    def forward(self, input_ids: Tensor) -> Tensor:
        """Fetches the character-CNN embeddings of a batch.

        Parameters
        ----------
        input_ids : (batch_size, max_sen_len)

        Returns
        -------
        inputs_embeds : (batch_size, max_sen_len, emb_dim)
        """
inputs_embeds = torch.zeros((*input_ids.shape, 1024), device=self.device)
for batch_idx in range(input_ids.shape[0]):
tokens: List[str] = [
self.vocab.i2w[token_idx] for token_idx in input_ids[batch_idx]
]
for i, token in enumerate(tokens):
if token not in self.cnn_embs:
char_ids = self.vocab.token_to_char_ids(token)
input_dict = {self.cnn_t["char_inputs_in"]: char_ids}
emb = torch.from_numpy(
self.cnn_sess.run(self.cnn_t["all_embs"], input_dict)
).to(self.device)
self.cnn_embs[token] = emb
else:
emb = self.cnn_embs[token]
inputs_embeds[batch_idx, i] = emb
return inputs_embeds
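
    # A usage sketch with hypothetical inputs, assuming "the" and "cat" are
    # in the vocabulary:
    #
    #   char_cnn = CharCNN(ckpt_dir, vocab, "cpu")
    #   ids = torch.tensor([[vocab["the"], vocab["cat"]]])
    #   embs = char_cnn(ids)    # shape: (1, 2, 1024)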


class SoftMax:
def __init__(
self,
vocab: C2I,
full_vocab_path: str,
ckpt_dir: str,
hidden_size_h: int,
device: str,
) -> None:
print("Loading SoftMax...")
self.decoder_w: Tensor = torch.zeros((len(vocab), hidden_size_h), device=device)
self.decoder_b: Tensor = torch.zeros(len(vocab), device=device)
self._load_softmax(vocab, full_vocab_path, ckpt_dir)
def _load_softmax(self, vocab: C2I, full_vocab_path: str, ckpt_dir: str) -> None:
from tensorflow.compat.v1.train import NewCheckpointReader
with open(full_vocab_path, encoding="ISO-8859-1") as f:
full_vocab: List[str] = [token.strip() for token in f]
bias_reader = NewCheckpointReader(os.path.join(ckpt_dir, "ckpt-softmax8"))
full_bias = torch.from_numpy(bias_reader.get_tensor("softmax/b"))
# SoftMax is chunked into 8 arrays of size 100000x1024
for i in range(8):
sm_reader = NewCheckpointReader(os.path.join(ckpt_dir, f"ckpt-softmax{i}"))
sm_chunk = torch.from_numpy(sm_reader.get_tensor(f"softmax/W_{i}"))
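            # Tokens are distributed round-robin over the 8 chunks: full-vocab
            # row r is stored at position r // 8 of chunk r % 8.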
bias_chunk = full_bias[i : full_bias.shape[0] : 8]
vocab_chunk = full_vocab[i : full_bias.shape[0] : 8]
for j, w in enumerate(vocab_chunk):
sm = sm_chunk[j]
bias = bias_chunk[j]
if w in vocab and vocab[w] < self.decoder_w.shape[0]:
self.decoder_w[vocab[w]] = sm
self.decoder_b[vocab[w]] = bias