Photo by Maarten Deckers on Unsplash

To understand LSTM architecture, code a forward pass with just NumPy

Greg Condit, May 29

If you’re like me, you learn best by starting simple and building from the ground up.
No matter how often I read colah’s famous LSTM post, or Karpathy’s post on RNNs (great resources!), the LSTM network architecture seemed overly complicated and the gates were hazy.
Eventually, I wrote an LSTM forward pass with just NumPy.
I highly recommend doing this exercise, because you can examine the hidden and cell states, inputs and outputs, and clearly see how they were created.
In this walkthrough, we will:
- Represent each of the LSTM gates with its own distinct Python function
- Chain them together logically to create a NumPy network with a single LSTM cell

Here’s the point: these functions can be edited to include print statements, which lets you examine the shapes of your data, the hidden state (short-term memory), and the cell state (long-term memory) at various stages throughout the LSTM cell.
This will help you understand why the gates exist.
Secondly, to confirm we’ve done it right, we will:
- Set up a small neural network with a single LSTM cell using PyTorch
- Initialize both networks with the same random weights
- Make one forward pass with both networks, and check that the output is the same

Let’s go!

Part 1: Creating the NumPy Network

Below is the LSTM Reference Card.
It contains the Python functions, as well as an important diagram.
The diagram shows every individual operation and variable (inputs, weights, states) used in the LSTM gate functions.
They are color-coded to match the gate they belong to.
The code for the functions can be copied below the card.
I highly recommend saving this reference card and using it to analyze and understand LSTM architecture.
(A printable PDF version is available for download here.)

Here’s the copy-able code from the above reference card, one function for each gate:

Typically, an LSTM feeds a final, fully-connected linear layer.
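For orientation, here is a minimal sketch of what one function per gate, plus the final linear layer, could look like in NumPy. It is not the code from the reference card: the function names, argument order, and the tiny sizes in the demo are assumptions chosen for illustration. The weight names follow the Wx_/Wh_ and Bx_/Bh_ convention used later in this post.

```python
import numpy as np

def sigmoid(z):
    """Squash values into (0, 1); every gate uses this as its filter."""
    return 1.0 / (1.0 + np.exp(-z))

def forget_gate(x, h, Wh_f, Bh_f, Wx_f, Bx_f, prev_cell_state):
    """Scale the previous cell state by how much of it should be kept."""
    keep = sigmoid(Wh_f @ h + Bh_f + Wx_f @ x + Bx_f)
    return keep * prev_cell_state

def input_gate(x, h, Wh_i, Bh_i, Wx_i, Bx_i, Wh_l, Bh_l, Wx_l, Bx_l):
    """Combine the 'ignore' filter with the candidate values to 'learn'."""
    ignore = sigmoid(Wh_i @ h + Bh_i + Wx_i @ x + Bx_i)
    learn = np.tanh(Wh_l @ h + Bh_l + Wx_l @ x + Bx_l)
    return ignore * learn

def cell_state(forget_gate_output, input_gate_output):
    """New long-term memory: what was kept plus what was learned."""
    return forget_gate_output + input_gate_output

def output_gate(x, h, Wh_o, Bh_o, Wx_o, Bx_o, new_cell_state):
    """New hidden state: a filtered view of the updated cell state."""
    out = sigmoid(Wh_o @ h + Bh_o + Wx_o @ x + Bx_o)
    return out * np.tanh(new_cell_state)

def model_output(lstm_output, fc_weight, fc_bias):
    """Final fully-connected linear layer applied to the hidden state."""
    return fc_weight @ lstm_output + fc_bias

# Demo with made-up sizes and random weights (assumptions, not the post's values).
rng = np.random.default_rng(0)
input_size, hidden_size = 1, 3
x = rng.normal(size=input_size)   # one event from a time series
h = np.zeros(hidden_size)         # hidden state (short-term memory)
c = np.zeros(hidden_size)         # cell state (long-term memory)

def rand_params():
    """Return (Wh, Bh, Wx, Bx) for a single gate."""
    return (rng.normal(size=(hidden_size, hidden_size)), rng.normal(size=hidden_size),
            rng.normal(size=(hidden_size, input_size)), rng.normal(size=hidden_size))

f = forget_gate(x, h, *rand_params(), c)
i = input_gate(x, h, *rand_params(), *rand_params())
c = cell_state(f, i)
h = output_gate(x, h, *rand_params(), c)
y = model_output(h, rng.normal(size=(1, hidden_size)), rng.normal(size=1))
print("cell state:", c.shape, "hidden state:", h.shape, "output:", y.shape)
```

Adding a print statement inside any of these functions is an easy way to watch the shapes and values change as data moves through the cell.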
Let’s do that as well:

Part 2: Compare to a PyTorch LSTM

Create a PyTorch LSTM with the same parameters.
PyTorch will automatically assign the weights with random values — we’ll extract those and use them to initialize our NumPy network as well.
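As a rough sketch of this step, the snippet below builds a single-layer `nn.LSTM` with a linear layer after it and lists the randomly initialized parameters. The sizes (`input_size = 1`, `hidden_size = 3`), the seed, and the variable names are assumptions for illustration, not the values used in the original code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # only so the random weights are repeatable

input_size = 1    # assumed: one feature per time step
hidden_size = 3   # assumed: a deliberately tiny hidden state

lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=1)
fc = nn.Linear(hidden_size, 1)  # final fully-connected layer

# PyTorch has already filled these with random values.
for name, tensor in lstm.state_dict().items():
    print(name, tuple(tensor.shape))

# Copy the weights out as NumPy arrays so a NumPy network can reuse them.
numpy_weights = {name: t.detach().numpy().copy()
                 for name, t in lstm.state_dict().items()}
```

Each weight and bias has a first dimension of `4 * hidden_size`, because PyTorch stacks the four gates into a single tensor.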
Don’t get overwhelmed! The PyTorch documentation explains all we need to break this down:
- The weights for each gate are stacked in this order: ignore, forget, learn, output (PyTorch’s own names are input, forget, cell, and output, or i, f, g, o)
- Keys with ‘ih’ in the name are the weights/biases for the input, or Wx_ and Bx_
- Keys with ‘hh’ in the name are the weights/biases for the hidden state, or Wh_ and Bh_

Given the parameters we chose, we can therefore extract the weights for the NumPy LSTM to use in this way:

Now, we have two networks, one in PyTorch and one in NumPy, with access to the same starting weights.
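The gate ordering above amounts to a four-way split along the first axis of each stacked array. Here is a small sketch of that split using stand-in NumPy arrays with the shapes `nn.LSTM` uses. The helper `split_gates`, the `GATE_ORDER` tuple, and the sizes are hypothetical names and values for illustration; in practice the arrays would come from the PyTorch model’s `state_dict()`.

```python
import numpy as np

# This post's gate names, in the order PyTorch stacks them (i, f, g, o).
GATE_ORDER = ("ignore", "forget", "learn", "output")

def split_gates(stacked):
    """Split an array that stacks four gates on axis 0 into a dict by gate name."""
    return dict(zip(GATE_ORDER, np.split(stacked, 4, axis=0)))

input_size, hidden_size = 1, 3  # assumed sizes

# Stand-ins shaped like the entries of an nn.LSTM state_dict.
state = {
    "weight_ih_l0": np.arange(4 * hidden_size * input_size, dtype=float)
                      .reshape(4 * hidden_size, input_size),
    "weight_hh_l0": np.zeros((4 * hidden_size, hidden_size)),
    "bias_ih_l0": np.zeros(4 * hidden_size),
    "bias_hh_l0": np.zeros(4 * hidden_size),
}

Wx = split_gates(state["weight_ih_l0"])  # input weights per gate
Wh = split_gates(state["weight_hh_l0"])  # hidden-state weights per gate
Bx = split_gates(state["bias_ih_l0"])    # input biases per gate
Bh = split_gates(state["bias_hh_l0"])    # hidden-state biases per gate

for gate in GATE_ORDER:
    print(gate, Wx[gate].shape, Wh[gate].shape, Bx[gate].shape, Bh[gate].shape)
```

Keeping the split in one helper means the gate order is written down in exactly one place, which is where a mix-up would otherwise be easiest to make.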
We’ll put some time series data through each to ensure they are identical.
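One way to sketch that check end to end is shown below. It is a compact stand-in rather than the code from the gist: the sizes, the seed, the made-up five-point series, and the helper names (`to_np`, `lstm_step`) are all assumptions, and the four gates are folded into a single step function for brevity.

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size = 1, 3           # assumed sizes

lstm = nn.LSTM(input_size, hidden_size)  # PyTorch assigns random weights
fc = nn.Linear(hidden_size, 1)           # final fully-connected layer

def to_np(t):
    """Copy a tensor into a float64 NumPy array."""
    return t.detach().numpy().astype(np.float64)

# Split the stacked tensors into gates: ignore, forget, learn, output.
sd = lstm.state_dict()
Wx = np.split(to_np(sd["weight_ih_l0"]), 4)
Wh = np.split(to_np(sd["weight_hh_l0"]), 4)
Bx = np.split(to_np(sd["bias_ih_l0"]), 4)
Bh = np.split(to_np(sd["bias_hh_l0"]), 4)
fc_W, fc_b = to_np(fc.weight), to_np(fc.bias)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h, c):
    """One pass through the LSTM cell for a single event x."""
    pre = [Wx[k] @ x + Bx[k] + Wh[k] @ h + Bh[k] for k in range(4)]
    ignore, forget = sigmoid(pre[0]), sigmoid(pre[1])
    learn, out = np.tanh(pre[2]), sigmoid(pre[3])
    c = forget * c + ignore * learn      # new cell state (long-term memory)
    h = out * np.tanh(c)                 # new hidden state (short-term memory)
    return h, c

# Made-up time series with shape (seq_len, input_size).
data = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]])

# NumPy network: feed the events through the cell in sequence.
h, c = np.zeros(hidden_size), np.zeros(hidden_size)
numpy_outputs = []
for x in data:
    h, c = lstm_step(x, h, c)
    numpy_outputs.append(fc_W @ h + fc_b)
numpy_outputs = np.array(numpy_outputs)

# PyTorch network: same data, shaped (seq_len, batch=1, input_size).
with torch.no_grad():
    seq = torch.tensor(data, dtype=torch.float32).unsqueeze(1)
    lstm_out, (h_n, c_n) = lstm(seq)
    torch_outputs = fc(lstm_out).squeeze(1).numpy()

print("outputs match:", np.allclose(numpy_outputs, torch_outputs, atol=1e-5))
print("hidden states match:", np.allclose(h, h_n.numpy().ravel(), atol=1e-5))
print("cell states match:", np.allclose(c, c_n.numpy().ravel(), atol=1e-5))
```

The tolerance is there because PyTorch computes in float32 while this NumPy sketch uses float64, so the two results agree only up to rounding.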
To do a forward pass with our network, we’ll pass the data into the LSTM gates in sequence, and print the output after each event:

Good news! Putting the same data through the PyTorch model shows that we return identical output:

We can additionally verify that after the data has gone through the LSTM cells, the two models have the same hidden and cell states:

I hope this helps build an intuition for how LSTM networks make predictions! You can copy the whole code at once from this gist.
The LSTM Reference Card can be downloaded from this page, including printable versions.
Thanks to Mike Ricos for collaborating on the creation of this great resource.