When recurrent models don't
need to be recurrent
Moritz Hardt
Joint work with John Miller
Sequence problems are everywhere.
Machine translation
Video captioning
Caption + translate
Video + speech synthesis
Conventional wisdom
Sequence problems are best solved by recurrent models.
"For most deep learning practitioners, sequence modeling is synonymous with
recurrent networks."
— Bai, Kolter, Koltun (2018)
Recurrent models
Maintain hidden state $h_t$ according to:
$$\begin{aligned} h_{t+1} &= \varphi_w(h_t, x_t) \\ y_t &= f(h_t) \end{aligned}$$
System specified by parameters $w \in \mathbb{R}^d$.
Typical setting: $\varphi_w(h_t, x_t) = \tanh(A h_t + B x_t)$. Here, $w = (A, B)$.
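As a sketch, the recurrence above fits in a few lines of NumPy. The dimensions and the linear readout $f(h) = Ch$ are illustrative assumptions, not choices from the talk.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 4, 3                          # hidden and input dimensions (assumed)
A = 0.3 * rng.normal(size=(d, d))    # recurrent weights
B = rng.normal(size=(d, n))          # input weights
C = rng.normal(size=(1, d))          # readout f(h) = C h (assumed linear)

def run_rnn(xs, A, B, C):
    """Unroll h_{t+1} = tanh(A h_t + B x_t), y_t = f(h_t) over a sequence."""
    h = np.zeros(A.shape[0])
    ys = []
    for x in xs:
        h = np.tanh(A @ h + B @ x)   # state-transition map phi_w
        ys.append(C @ h)             # output y_t
    return np.array(ys)

xs = rng.normal(size=(10, n))        # a length-10 input sequence
ys = run_rnn(xs, A, B, C)
```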
Advantages of recurrent models

Expressive model class

Flexible formalism to cope with variable sequence length

Potential to capture long-term dependencies

Historical and cultural heritage (the watch your father gave you)
Training recurrent models
Update model parameters with gradient descent:
$$w \leftarrow w - \alpha \nabla_w \ell(w; (x, y))$$
Get derivatives using backpropagation
Need to store all $t$ hidden states for the backward pass!
Shortcomings of recurrent models

Finicky during training (vanishing/exploding gradients)

Extensive tuning necessary

Backprop prohibitive on long sequences

Truncation often necessary in practice
Empirical reality
Increasingly, feedforward models replace recurrent models in various
applications.
Language Modeling (GatedConv by Dauphin et al.),
Machine Translation (Transformer by Vaswani et al.),
Speech Synthesis (WaveNet by van den Oord et al.),
everything else (Bai et al.)
Question
Are all trainable recurrent models
inherently feedforward?
Sounds plausible to me, but it's hard
to characterize trainability
Main result
Stable recurrent models can be replaced by equivalent
feedforward models for both training and inference.
Put differently, any recurrent model is either
inherently feedforward or unstable.
Empirical work
Ways of making models (RNNs, LSTMs) stable, also during training
Stable recurrent models can achieve
competitive performance on various tasks.
Stability
State-transition map $\varphi$ is $\lambda$-contractive if
$$\Vert \varphi(h, x) - \varphi(h', x)\Vert \le \lambda \Vert h - h'\Vert$$
Stable: $\lambda < 1$
Stability
Linear dynamical system $\varphi(h, x) = Ah + Bx$ is stable if
$\Vert A \Vert < 1$.
Same for the recurrent neural network
$\varphi(h, x) = \tanh(Ah + Bx)$
(since $\tanh$ is $1$-Lipschitz.)
And, since you're asking,
the LSTM is stable if a corresponding norm condition
on its weight matrices holds.
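A minimal numerical sketch of the stability condition $\Vert A \Vert < 1$: check the spectral norm, and rescale when the condition fails. The rescaling projection is one simple way to enforce stability, assumed here for illustration.

```python
import numpy as np

def spectral_norm(A):
    """Largest singular value, i.e. the operator 2-norm ||A||."""
    return np.linalg.norm(A, 2)

def project_stable(A, margin=0.99):
    """Rescale A so that ||A|| <= margin < 1; stable A is left unchanged."""
    s = spectral_norm(A)
    return A if s <= margin else A * (margin / s)

rng = np.random.default_rng(1)
A = rng.normal(size=(8, 8))          # almost surely unstable at this scale
A_stable = project_stable(A)
```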
Truncated model
Model truncated at depth $k$:
$$\begin{aligned} h_{t+1}^k &= \varphi_w(h_t^k, x_t) \\ h_{t-k}^k &= 0 \\ y_t^k &= f(h_t^k) \end{aligned}$$
Autoregressive feedforward model of depth $k$ times the original depth
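Since the truncated model only ever looks at the last $k$ inputs, it can be sketched as a feedforward map over a sliding window. The tanh cell, readout, and dimensions are illustrative assumptions.

```python
import numpy as np

def truncated_output(xs, t, k, A, B, C):
    """Compute y_t^k: unroll phi_w over only the last k inputs, from h = 0."""
    h = np.zeros(A.shape[0])                 # h_{t-k}^k = 0
    for x in xs[max(0, t - k): t]:           # window of (at most) k inputs
        h = np.tanh(A @ h + B @ x)           # same state-transition map
    return C @ h                             # y_t^k = f(h_t^k)

rng = np.random.default_rng(2)
d, n = 4, 3
A = 0.5 * np.eye(d)                          # ||A|| = 0.5 < 1: stable
B = rng.normal(size=(d, n))
C = rng.normal(size=(1, d))
xs = rng.normal(size=(50, n))
y_k = truncated_output(xs, t=50, k=10, A=A, B=B, C=C)
```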
Inference
Theorem.
Assume $\varphi_w$ is $\lambda$-contractive for $\lambda < 1$ and Lipschitz.
Then $\Vert y_t - y_t^k \Vert \le \epsilon$
provided $k \gg \log\!\left(\frac{1}{(1-\lambda)\epsilon}\right)$.
Training with gradient descent
Let $w_N$ be the parameters of the recurrent model
after $N$ steps of gradient descent with a decaying $O(1/t)$ learning rate.
Let $w_N^k$ be the same for the truncated model.
Theorem.
Under stability, smoothness, and Lipschitz assumptions,
$\Vert w_N - w_N^k \Vert \le \epsilon$ provided $k \gg \log(N/\epsilon)$.
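A toy numerical sketch of this result: train a scalar stable recurrent model and its $k$-truncation with the same decaying step size, then compare the learned parameters. The model, loss, data, projection radius, and finite-difference gradients are all illustrative assumptions, not the paper's setup.

```python
import numpy as np

rng = np.random.default_rng(3)
xs = 0.5 * rng.normal(size=(100,))           # one toy input sequence
target = 0.3                                 # regression target for h_T

def final_state(w, xs, k=None):
    """h_{t+1} = w * tanh(h_t) + x_t; contractive whenever |w| < 1."""
    seq = xs if k is None else xs[-k:]       # truncation keeps the last k inputs
    h = 0.0
    for x in seq:
        h = w * np.tanh(h) + x
    return h

def grad(w, xs, k=None, eps=1e-6):
    """Finite-difference gradient of the squared loss (illustrative only)."""
    loss = lambda v: (final_state(v, xs, k) - target) ** 2
    return (loss(w + eps) - loss(w - eps)) / (2 * eps)

def train(k=None, steps=200, w0=0.1):
    w = w0
    for t in range(1, steps + 1):
        w = w - (0.1 / t) * grad(w, xs, k)   # decaying O(1/t) learning rate
        w = float(np.clip(w, -0.5, 0.5))     # project back onto a stable set
    return w

w_full, w_trunc = train(), train(k=20)       # full model vs. depth-20 truncation
```

With a contractive model, the two gradient trajectories differ by a quantity geometrically small in $k$, so the learned parameters end up very close.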
Training vs inference
Truncation affects neither training nor inference.
Image credit: Song Han
Contrast with weight pruning
Preserves accuracy when done after training
Can hurt performance when done before training
Proof idea
Inference:
Direct calculation using contractiveness assumption.
Training:
Uses stability properties of gradient descent.
Inspired by the "train
faster, generalize better" analysis.
[Hardt, Recht, Singer, 2015]
Proof of inference result
If $\varphi_w$ is $\lambda$-contractive, then:
$$\Vert h_t - h_t^k \Vert \le \lambda^k \Vert h_{t-k} - h_{t-k}^k \Vert = \lambda^k \Vert h_{t-k} \Vert$$
Now assume $\varphi_w$ is $L$-Lipschitz in $x_t$, and $\Vert x_t \Vert \le B$.
$$\Vert h_t \Vert \le \lambda \Vert h_{t-1} \Vert + LB \le \sum_{i=0}^{t} \lambda^i LB \le \frac{LB}{1-\lambda}\,.$$
Hence,
$$\Vert h_t - h_t^k \Vert \le \frac{\lambda^k LB}{1-\lambda}\,.$$
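The final bound can be checked numerically on a small tanh-RNN. Here $\lambda = \Vert A \Vert$, $L = \Vert B \Vert$, and the input bound are computed from randomly chosen, purely illustrative matrices and inputs.

```python
import numpy as np

rng = np.random.default_rng(4)
d, n = 6, 2
A = rng.normal(size=(d, d))
A *= 0.8 / np.linalg.norm(A, 2)              # enforce lambda = ||A|| = 0.8 < 1
B_mat = rng.normal(size=(d, n))
lam, L = 0.8, np.linalg.norm(B_mat, 2)       # L: Lipschitz constant in x

xs = rng.uniform(-1, 1, size=(60, n))
B_in = max(np.linalg.norm(x) for x in xs)    # input norm bound ||x_t|| <= B

def state(seq):
    """Unroll h <- tanh(A h + B x) from h = 0 over the given inputs."""
    h = np.zeros(d)
    for x in seq:
        h = np.tanh(A @ h + B_mat @ x)
    return h

t, k = 60, 15
gap = np.linalg.norm(state(xs[:t]) - state(xs[t - k:t]))  # ||h_t - h_t^k||
bound = lam ** k * L * B_in / (1 - lam)                   # lambda^k L B / (1 - lambda)
```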
Proof of training result
Step 1:
Show that stable models have vanishing gradients (over time).
I.e., argue that $\nabla_w \ell \approx \nabla_w \ell^k \pm \epsilon$
A (not so obvious) extension of the inference result
Step 2:
Show that gradient descent is insensitive to gradient differences.
Connection to algorithmic stability of gradient descent
Connection to algorithmic stability
Gradient descent is insensitive to perturbations in the
sample (algorithmic stability) [Hardt, Recht, Singer, 2015]
Similar analysis applies here
A decaying learning rate is necessary for the analysis in both cases.
Here, it is even necessary in experiments.
Word-level language modeling

Language modeling on Wikitext2. Similar to Penn Treebank data, but bigger.

One giant sequence.

Task: Predict next word.
Stable RNN slightly better than unstable RNN. Opposite for LSTM.
Polyphonic music modeling

Corpus of Bach chorales.

Each sequence encodes which piano keys were pressed at a given time.

Task: Predict the next chord in the sequence.
Stable/unstable LSTMs and RNNs have very similar performance.
Stable slightly better.
Character-level language modeling

Language modeling on Tolstoy's War and Peace.

One giant sequence.

Task: Predict next character.
Stable worse than unstable. No difference between LSTM and RNN.
Summary
Stable recurrent models do not provide long-term memory and can be replaced by
feedforward models without loss.
Empirical evidence suggests that stable models can achieve good performance.
Some speculation
Are all efficiently trainable recurrent models
inherently subsumed by
feedforward models?
In this work: Stability as a proxy for trainability.
Future work
When are recurrent models really needed (if they ever are)?
What is the price of stability (if there is any)?
What memory mechanisms are efficiently learnable?
What kind of memory do natural problems require?
What classes of dynamical systems are learnable?
(Tricky even for linear dynamical systems)