Gated Recurrent Units - GRU4REC recommender system

What is GRU4REC?

  • GRU4REC is a recurrent neural network architecture built on gated recurrent units (GRUs) and designed to act as a session-based recommender system. It caught my eye because there are cases in the industry where a user cannot be identified by an ID (they might not be registered), so the only feasible way to identify such a user is through a session. It is also one of the few applications of RNNs to the recommendation field. You can find the paper here.

How does it work?

  • One use of RNNs is to model data that arrives as a sequence of some length. An example is suggesting the next word of a given phrase in a book, after processing all the phrases in the books by the same author. The sequence often represents a temporal series of events. RNNs differ from standard feed-forward neural networks in that they maintain a hidden state, which can be interpreted as a memory of previous inputs. This recurrence can be pictured as a loop that unfolds the hidden layer N times, where N is the length of the sequence being analyzed. At each step, the output of the network is computed from the input at that step and the hidden state at that step, which summarizes all the previous steps.
  • Since RNNs are trained using back-propagation, they suffer from the vanishing gradient problem. The weights of the network are updated using the gradient of the computed loss, and that gradient shrinks as the error propagates back through the network at each time step. Smaller gradients mean smaller weight updates, so the RNN fails to learn long-range dependencies across time steps.
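
  • To make the shrinking gradient concrete, here is a minimal sketch using a scalar RNN in plain Python. The recurrence h_t = tanh(w_h * h_{t-1} + w_x * x_t), the weights (0.5 and 1.0) and the constant input are assumptions picked for illustration, not values from the paper. The gradient of the last hidden state with respect to the initial one is a product of one factor per time step, and each factor here is smaller than 1.

```python
import math

def rnn_forward(xs, w_h=0.5, w_x=1.0):
    """Run a scalar RNN h_t = tanh(w_h * h_{t-1} + w_x * x_t); return every hidden state."""
    h = 0.0
    hs = []
    for x in xs:
        h = math.tanh(w_h * h + w_x * x)
        hs.append(h)
    return hs

def grad_wrt_initial_state(hs, w_h=0.5):
    """Chain rule: d h_t / d h_0 is the running product of w_h * (1 - h_s^2) for s <= t."""
    grad = 1.0
    grads = []
    for h in hs:
        grad *= w_h * (1.0 - h * h)  # derivative of tanh times the recurrent weight
        grads.append(grad)
    return grads

xs = [0.5] * 30                      # illustrative sequence of 30 identical inputs
hs = rnn_forward(xs)
grads = grad_wrt_initial_state(hs)

# Gradient after 1, 10 and 30 steps: it decays roughly geometrically.
print(grads[0], grads[9], grads[29])
```

  • With these settings every factor is at most 0.5, so after 30 steps the gradient is below 0.5^30. An early input then has almost no influence on the weight update, which is the long-range dependency problem described above.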

  • Gated Recurrent Units help RNNs mitigate this issue. They change how the hidden state is computed by learning when to update it and how much of the previous state to keep. They use a reset gate r(t), an update gate z(t) and a candidate hidden state h'(t), which holds the new memory content.
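
  • The gates can be sketched as a single GRU step in NumPy. This is an illustration under assumptions: the weight names (W_*, U_*, b_*), the sizes and the random initialization are hypothetical, and the final blend uses the convention h(t) = (1 - z(t)) * h(t-1) + z(t) * h'(t). Some libraries swap the roles of z(t) and 1 - z(t).

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x, h_prev, p):
    """One GRU step. p maps names to weights: W_* (input), U_* (recurrent), b_* (bias)."""
    z = sigmoid(p["W_z"] @ x + p["U_z"] @ h_prev + p["b_z"])             # update gate z(t)
    r = sigmoid(p["W_r"] @ x + p["U_r"] @ h_prev + p["b_r"])             # reset gate r(t)
    h_cand = np.tanh(p["W_h"] @ x + p["U_h"] @ (r * h_prev) + p["b_h"])  # candidate h'(t)
    # z decides how much of the candidate replaces the previous state.
    return (1.0 - z) * h_prev + z * h_cand

rng = np.random.default_rng(0)
input_dim, hidden_dim = 4, 3         # illustrative sizes
params = {}
for gate in ("z", "r", "h"):
    params[f"W_{gate}"] = rng.normal(scale=0.1, size=(hidden_dim, input_dim))
    params[f"U_{gate}"] = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
    params[f"b_{gate}"] = np.zeros(hidden_dim)

h = np.zeros(hidden_dim)
for x in rng.normal(size=(5, input_dim)):  # a random sequence of 5 inputs
    h = gru_step(x, h, params)
print(h)
```

  • When z(t) is close to 0 the previous state is copied almost unchanged, which gives the gradient a path through time that is not repeatedly squashed.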

How can we implement this?

  • If we consider a session (user) that visits (or purchases) a series of products, each at a specific timestamp, what we have is a sequence of product IDs per session, which fits our use case.

  • To simplify things, we can look at the following TensorFlow link and apply it to our use case.

from tensorflow.keras import Model
from tensorflow.keras.layers import Embedding, GRU, Dropout, Dense
from tensorflow.keras.losses import SparseCategoricalCrossentropy

class GRU4Rec(Model):
    def __init__(self, vec_size, embedding_dim, rnn_units, dropout):
        super().__init__()
        self.embedding = Embedding(vec_size, embedding_dim)
        self.gru = GRU(rnn_units, return_sequences=True, return_state=True)
        self.dropout = Dropout(dropout)
        # Output raw logits (no softmax): the loss is built with from_logits=True.
        self.dense = Dense(vec_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        x = self.embedding(inputs, training=training)
        # initial_state=None lets the GRU start from an all-zero state.
        x, states = self.gru(x, initial_state=states, training=training)
        # Pass training explicitly so dropout is only active during training.
        x = self.dropout(x, training=training)
        x = self.dense(x, training=training)
        if return_state:
            return x, states
        else:
            return x

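# n_items, embedding_dim, rnn_units and dropout are assumed to be defined beforehand.
model = GRU4Rec(
    vec_size=len(n_items),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    dropout=dropout
)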

loss = SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss, metrics=['accuracy'])
model.fit(data)
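
  • The `data` passed to `model.fit` is not defined above. As a rough sketch of one way to prepare it, the snippet below turns a click log into (input, target) sequences, where the target is the input shifted by one item. The `events` log, the `build_sequences` helper and the choice to reserve index 0 for padding are all assumptions made for illustration.

```python
from collections import defaultdict

# Hypothetical click log: (session_id, item_id, timestamp)
events = [
    ("s1", "shoes", 3), ("s1", "socks", 1), ("s1", "laces", 2),
    ("s2", "socks", 1), ("s2", "shoes", 2),
    ("s3", "hat", 1),
]

def build_sequences(events):
    """Group events by session, order them by time, and emit (input, target) index lists."""
    sessions = defaultdict(list)
    for session_id, item_id, ts in events:
        sessions[session_id].append((ts, item_id))

    item_to_idx = {}  # index 0 is left free for padding, so items start at 1
    pairs = []
    for session_id in sorted(sessions):
        items = [item for _, item in sorted(sessions[session_id])]
        ids = [item_to_idx.setdefault(item, len(item_to_idx) + 1) for item in items]
        if len(ids) < 2:
            continue  # a single click has no "next item" to predict
        pairs.append((ids[:-1], ids[1:]))  # predict item t+1 from the items up to t
    return pairs, item_to_idx

pairs, item_to_idx = build_sequences(events)
print(pairs)  # [([1, 2], [2, 3]), ([1], [3])]
```

  • Before training, these variable-length lists would still need to be padded and batched, and the embedding size would have to cover the padding index as well as every item.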

How can we deploy such a model?

  • I’ve tackled this problem, and one efficient service I’ve used is AWS SageMaker, since it provides a cloud framework to train and deploy deep learning models. Keep in mind that, if the model is going to be served online, it needs to respond to HTTP requests. To simplify this, I used Flask as the application server, with Gunicorn and Nginx in front of it. You need to build a Docker image that contains your Flask server so that AWS SageMaker can pull that image and serve your model endpoint. You can store your images in AWS Elastic Container Registry.
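
  • As a minimal sketch of the Flask side, SageMaker expects a serving container to answer a GET on /ping (health check) and a POST on /invocations (prediction) on port 8080. The `recommend` function and the `session_items` payload field below are hypothetical placeholders. A real service would load the trained model at startup and call it instead.

```python
import json

from flask import Flask, Response, request

app = Flask(__name__)

def recommend(session_items, k=3):
    """Hypothetical stand-in for the model: echoes the k most recent items, newest first."""
    return list(reversed(session_items))[:k]

@app.route("/ping", methods=["GET"])
def ping():
    # Health check: a 200 response tells SageMaker the container is ready.
    return Response(status=200)

@app.route("/invocations", methods=["POST"])
def invocations():
    # Assumed request body: {"session_items": [<item index>, ...]}
    payload = request.get_json(force=True)
    items = payload.get("session_items", [])
    body = json.dumps({"recommendations": recommend(items)})
    return Response(body, status=200, mimetype="application/json")
```

  • Inside the Docker image, Gunicorn would be pointed at the `app` object and Nginx would forward port 8080 to it. For a quick local check, Flask's built-in test client can call both routes without starting a server.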
