Bengio et al. model


from collections import Counter

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
from nltk import word_tokenize


class BengioModel(keras.Model):
    ''' Model that replicates the architecture of Bengio et al.  '''
    def __init__(self, window_size: int, vocabulary_size: int, embedding_size: int = 60, hidden_units: int = 50, regulariser_l: float = 0.1):
        ''' Initialise model. 
        
        Args:
        - window_size :: Number of words used for context.
        - vocabulary_size :: Size of the vocabulary in the corpus. 
        - embedding_size :: Size of the embedding layer.
        - hidden_units :: Number of hidden units in the hidden layer.
        - regulariser_l :: Strength of the L2 regularisation (as l -> inf the penalty dominates and shrinks the weights towards zero).
          NOTE: The default value of 0.1 is *just* a placeholder, as the paper didn't specify strength.
        
        '''
        super().__init__()
        self.window_size = window_size
        self.vocabulary_size = vocabulary_size
        self.embedding_size = embedding_size
        # Takes the place of tanh(d + Hx)
        # You could easily chuck a few more layers here if you wanted to experiment with depth.
        # Not sure why the original paper uses tanh (possibly just convention at the time); substituting a relu here is worth experimenting with.
        self.non_linear = layers.Dense(hidden_units, activation=tf.nn.tanh)
        # NOTE: The paper doesn't specify whether the embedding should be regularised, so it is left unregularised here.
        self.embedding = layers.Embedding(vocabulary_size, embedding_size)
        self.W = layers.Dense(vocabulary_size, use_bias=False, kernel_regularizer=keras.regularizers.l2(l=regulariser_l))
        self.U = layers.Dense(vocabulary_size, use_bias=False, kernel_regularizer=keras.regularizers.l2(l=regulariser_l))
        self.b = tf.Variable(tf.random.uniform((vocabulary_size,), minval=-1, maxval=1))

    def call(self, inputs):
        embed = self.embedding(inputs)
        # The embedding output is a tensor of shape (batch_size, self.window_size, self.embedding_size), i.e. one embedding per word in the window.
        # This reshape concatenates the window's embeddings into a single vector per example.
        embed = tf.reshape(embed, (-1, self.embedding_size * self.window_size))
        act = self.non_linear(embed)
        non_linear = self.U(act)
        linear = self.W(embed)
        logit = linear + non_linear + self.b
        return logit
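
# A quick shape check (a hedged example, not part of the original post): the model should map a
# batch of integer word indices of shape (batch, window) to one logit per vocabulary word, i.e.
# (batch, vocabulary). This mirrors the paper's y = b + Wx + U tanh(d + Hx), where x is the
# concatenation of the context-word embeddings. The toy sizes below are arbitrary.
_toy_model = BengioModel(window_size=3, vocabulary_size=10, embedding_size=8, hidden_units=4)
assert _toy_model(tf.constant([[1, 2, 3], [4, 5, 6]])).shape == (2, 10)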


def window(tokens, n: int, pad):
    ''' Produce a rolling window of n tokens over a token list, paired with the token
    that follows the window as the label (using pad if we run out of elements). '''
    for i in range(len(tokens) - n + 1):
        yield tokens[i : i + n], (pad if i + n >= len(tokens) else tokens[i + n])
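
# A quick sanity check of the windowing (a hedged example, not from the original post):
# every window of n tokens should be paired with the token that follows it, and the final
# window should fall back to the pad token.
assert list(window(['a', 'b', 'c', 'd'], 2, '#')) == [(['a', 'b'], 'c'),
                                                      (['b', 'c'], 'd'),
                                                      (['c', 'd'], '#')]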

def load_data(filename: str, window_size: int):
    ''' Tokenise the corpus, build a vocabulary of the words that occur at least twice,
    and produce (window, label) training pairs. (This code is almost identical to what Ben had.) '''
    with open(filename, 'r') as f:
        counts = Counter()
        lines = []
        for line in f.readlines():
            tokenized = word_tokenize(line.lower())
            counts.update(tokenized)
            lines.append(tokenized)
        
        vocab = [w for w in counts if counts[w] >= 2]        # keep words that appear at least twice
        vocab.append('')                                     # '' stands in for unknown / rare words
        vocab.append('#')                                     # '#' is the pad token used by window()
        vocab_map = dict(map(reversed, enumerate(vocab)))    # maps word -> integer index
        windows = []
        labels = []

        for line in lines:
            line = [vocab_map.get(word, vocab_map['']) for word in line]
            for win, label in window(line, window_size, vocab_map['#']):
                windows.append(win)
                labels.append(label)
        return vocab_map, tf.constant(windows), tf.constant(labels)


def perplexity(y_true, y_pred):
    ''' Compute the perplexity of the model. '''
    ce = tf.reduce_mean(tf.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True))
    return tf.exp(ce)
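
# A small sanity check (a hedged example, not from the original post): with uniform logits over
# V classes the cross-entropy is log(V), so the perplexity should come out as (approximately) V.
assert abs(float(perplexity(tf.constant([0, 1, 2, 3]), tf.zeros((4, 10)))) - 10.0) < 1e-3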
   

WINDOW_SIZE = 6
NUM_EPOCHS = 50
EMBED_DIM = 60
HIDDEN_DIM = 60
BATCH_SIZE = 100
SEED = 31415

vocab_map, windows, labels = load_data('brown.txt', WINDOW_SIZE)

# Shuffle the window and label tensors.
# The initial shuffling will determine the train/val/test split. 
# The variable SEED controls what shuffle is produced.
tf.random.set_seed(SEED)
indices = tf.range(0, windows.shape[0], dtype=tf.int32)
shuffle = tf.random.shuffle(indices)
windows = tf.gather(windows, shuffle)
labels = tf.gather(labels, shuffle)


# This code splits the dataset into train/validation/test. 
# The way it's split is as follows:
#            train (64%)       val (16%)   test (20%)   
#  <-------------------------><-------><----------->
# [.................................................] (dataset)
# 
# Tweak the two 0.8 fractions below to change the proportions.
n = windows.shape[0]
split = int(0.8 * n)
val_split = int(0.8 * split)
train_windows = windows[:val_split]
train_labels = labels[:val_split]
val_windows = windows[val_split:split]
val_labels = labels[val_split:split]
test_windows = windows[split:]
test_labels = labels[split:]
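
# Quick sanity check (a hedged addition, not in the original post): the three splits
# should partition the shuffled dataset exactly.
assert train_windows.shape[0] + val_windows.shape[0] + test_windows.shape[0] == n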

# Checkpointing is super useful for making sure your progress isn't lost over a few hours of training.
# Basically it saves your weights to disk so you can load them back later, whether because one epoch
# looks interesting or because your computer dies.
# Check out https://keras.io/api/callbacks/model_checkpoint/ for details on the flags you can configure for this.

checkpoint = keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.2f}.h5',
                                             # Subclassed models can't be serialised whole to HDF5,
                                             # so only the weights are saved.
                                             save_weights_only=True,
                                             # This will only save when the validation loss improves.
                                             # Disable this to save every epoch but use more disk space.
                                             save_best_only=True)

vocab_size = len(vocab_map) + 1

model = BengioModel(WINDOW_SIZE, vocab_size, embedding_size=EMBED_DIM, hidden_units=HIDDEN_DIM)
# Because BengioModel subclasses the keras Model class you can do all sorts of interesting things with it.
# Check out https://keras.io/api/models/model/ for a list of supported methods and properties.
model.compile(optimizer='adam',
              loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=[perplexity])
model.fit(train_windows, train_labels,
          # Comment out this line to disable checkpointing
          callbacks=[checkpoint],
          batch_size=BATCH_SIZE,
          validation_data=(val_windows, val_labels),
          epochs=NUM_EPOCHS)
model.evaluate(test_windows, test_labels)
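
# To come back to a run later you can rebuild the model and load a saved checkpoint into it
# (a hedged sketch, not in the original post; the filename below is hypothetical, use whichever
# file ModelCheckpoint actually wrote to disk):
#   model.load_weights('model.05-5.43.h5')
#   model.evaluate(test_windows, test_labels)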

