This is an interactive creative-coding experiment - draw slowly with your touch to see how the LSTM responds:
LSTM + Touch drawing
This uses an LSTM (Long Short-Term Memory) - a type of RNN (recurrent neural network). It takes your touch input and produces a curvy, smooth output - similar to splines or Bézier curves, but with a more natural, playful personality attached to it.
Side note on LSTMs/other experiments
I came across LSTMs a while ago at work in the context of keyboard / text predictions, but I never really understood them. More recently, I started exploring RL with Puffer where LSTMs (and other RNNs like GRU) are important parts of the architecture. (More on that at a later date)
Although I lack a formal ML background, I found LSTMs both fascinating and confusing. So I built this demo to experience an LSTM right under your finger and to visualize its states in a concrete way. It also helps me see the long-term memory/states at work: drag slowly across the screen and the curves display a “personality”. A great primer on LSTMs is here by Christopher Olah (a co-founder of Anthropic), which also covers the math behind it (which is beyond me).
There are other great experiments, such as SketchRNN and Quick, Draw!, that (likely) used large datasets and labeled features to “teach” the RNN. For folks like me without access to large datasets or human labeling, I thought: well, human touch interaction generates a lot of data, and it’s readily available right under your fingertips.
I used touch input to generate lots of time-series training data for the LSTM without requiring labeled data, so the model can train/infer on a budget phone/computer. By sliding a window over the continuous touch time series, we effectively get supervised learning without explicit labels/features during back-propagation.
Training notes
I trained this by first drawing circle-ish curves using my own touch - you can see this personality (of circle-ish curves) peek through via the LSTM output during inference. Each batch of training data is a sliding window of time-series data: i.e. [(dx1, dy1, dt1)... (dxn, dyn, dtn)] as input and the RNN predicts the next part of the curve (dx, dy, dt).
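The sliding-window idea can be sketched like this (a minimal sketch, not the demo’s actual code - `Sample`, `makeTrainingPairs`, and `windowSize` are names I made up):

```typescript
// One touch sample: deltas between consecutive frames.
type Sample = [dx: number, dy: number, dt: number];

// Slide a fixed-size window over the touch stream: each window of
// `windowSize` samples is an input sequence, and the sample that
// immediately follows it is the prediction target. No human labels
// needed - the stream supervises itself.
function makeTrainingPairs(
  stream: Sample[],
  windowSize: number
): { input: Sample[]; target: Sample }[] {
  const pairs: { input: Sample[]; target: Sample }[] = [];
  for (let i = 0; i + windowSize < stream.length; i++) {
    pairs.push({
      input: stream.slice(i, i + windowSize),
      target: stream[i + windowSize],
    });
  }
  return pairs;
}
```

Each pair then feeds back-propagation directly: the window is the input sequence, the next delta is the target.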
The network is a tiny two-layer LSTM (with a simple input layer and a dense output layer). I chose this architecture to keep things simple while maximizing the ‘fun’/interesting element, and so it runs fast enough on a budget phone/computer. Training took about 10 minutes (90 epochs) on a decent desktop using the TensorFlow.js WebGL backend.
For inference I use the TF.js WASM backend, as it’s faster for smaller models, has multi-thread support, and doesn’t hog the GL context during drawing. I use TF.js instead of C++/libtorch here purely for fast iteration - and to step away from the libtorch/C++ code I have been working on for quite some time.
During inference, I capture the touch input and use the diff between the current frame (x, y, t) and the previous frame (x, y, t) as the input to the neural network. To make it more interesting, I nudge the predicted curve (diff) a little bit toward the touch point (i.e. your finger) so that the curve stays on the screen. Otherwise, the LSTM output veers off the screen like it has a mind of its own - still interesting, but a bit annoying since you really have to pull the curve back onto the screen. :-)
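The diff-plus-nudge step might look something like this (a sketch under my own naming; `pull` is a hypothetical tuning knob, not something from the actual demo):

```typescript
type Pt = { x: number; y: number; t: number };

// Frame-to-frame delta fed to the network.
function delta(curr: Pt, prev: Pt): [number, number, number] {
  return [curr.x - prev.x, curr.y - prev.y, curr.t - prev.t];
}

// Blend the LSTM's predicted delta slightly toward the finger so the
// curve stays on screen. `pull` in [0, 1] is a made-up knob:
// 0 = pure LSTM output (free to wander off screen),
// 1 = glued to the touch point.
function nudge(
  predicted: [number, number],
  pen: { x: number; y: number },
  touch: { x: number; y: number },
  pull: number
): [number, number] {
  const toFinger = [touch.x - pen.x, touch.y - pen.y];
  return [
    predicted[0] * (1 - pull) + toFinger[0] * pull,
    predicted[1] * (1 - pull) + toFinger[1] * pull,
  ];
}
```

A small `pull` keeps the LSTM’s personality while preventing the curve from escaping the canvas.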
Visualizing the LSTM states
The LSTM network looks like this:
A feedforward neural network differs from an RNN in that it carries no hidden state across timesteps. An LSTM (a recurrent neural network), on the other hand, carries state across timesteps via a side channel we explicitly maintain. H is the hidden state (the short-term output) and C is the cell state (the long-term memory): the gates are computed from the current input and the previous hidden state, and they decide what the cell state forgets, stores, and exposes.
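To make the H/C update concrete, here is a toy scalar LSTM step (my own illustration, not the demo’s code - real cells use learned weight matrices and vectors, not single numbers):

```typescript
const sigmoid = (x: number) => 1 / (1 + Math.exp(-x));

// Toy per-gate parameters: input weight, hidden weight, bias.
type W = { wx: number; wh: number; b: number };

// One scalar LSTM step. The gates (forget f, input i, output o) and
// the candidate g are all computed from the current input x and the
// previous hidden state h; the cell state c is the long-term memory,
// and the new hidden state h is read out from it.
function lstmStep(
  x: number,
  h: number,
  c: number,
  w: { f: W; i: W; o: W; g: W }
): { h: number; c: number } {
  const gate = (p: W, act: (z: number) => number) =>
    act(p.wx * x + p.wh * h + p.b);
  const f = gate(w.f, sigmoid);     // how much old memory to keep
  const i = gate(w.i, sigmoid);     // how much new candidate to write
  const g = gate(w.g, Math.tanh);   // candidate memory content
  const o = gate(w.o, sigmoid);     // how much memory to expose
  const cNew = f * c + i * g;       // long-term state (C)
  const hNew = o * Math.tanh(cNew); // short-term state / output (H)
  return { h: hNew, c: cNew };
}
```

Note that C is updated additively (keep some old, write some new), which is what lets information survive many timesteps - the “long-term memory” you can see in the visualization below.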
It’s a bit abstract, so here are some fun creative-coded visualizations:
(Draw below to see how the LSTM ‘thinks’)
Each LSTMCell here has 8 H/C pairs, and I visualize all 16 pairs across the two layers. Some of the H/C values clearly follow the X or Y direction; if you stare long enough you can probably see the long-term memory at work (draw a line followed by a curve, for example).
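One simple way to lay the state values out on screen (a hypothetical sketch - the demo’s actual mapping may differ; `statesToPoints`, `width`, and `height` are my names):

```typescript
// One point per H (or C) value: x indexed by unit, y driven by the
// state value. Hidden states are tanh-squashed so they sit roughly
// in [-1, 1]; cell states can drift outside that, hence the clamp.
function statesToPoints(
  states: number[], // e.g. the 16 H values across both layers
  width: number,
  height: number
): { x: number; y: number }[] {
  const mid = height / 2;
  return states.map((v, idx) => ({
    x: ((idx + 0.5) / states.length) * width,
    y: mid - Math.max(-1, Math.min(1, v)) * mid,
  }));
}
```

Sampling this every frame while you draw turns each unit into its own little trace.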
Another fun visualization
Draw slowly here:
This shows the H/C state curves laid out on top of the input curve. I thought it looked cool, but there really isn’t a point.
This is just a fun way to look at the LSTM internals - mostly because the math behind it confounded me and it felt like a ‘mystery machine’.
I will be writing about reinforcement learning in future blog posts. Follow me on X for updates.
Code + demo links
Code for the training+inference+drawing (TypeScript) is here, hosted using this HTML (build with tsc -p and open touch-draw.html?training=true&epochs=30).