When recurrent models don't
need to be recurrent
Moritz Hardt
Joint work with John Miller
Sequence problems are everywhere.
Machine translation
Video captioning
Caption + translate
Video + speech synthesis
Conventional wisdom
Sequence problems are best solved by recurrent models.
"For most deep learning practitioners, sequence modeling is synonymous with
recurrent networks."
— Bai, Kolter, Koltun (2018)
Recurrent models
Maintain hidden state $h_t$ according to:
$$\begin{aligned} h_{t+1} &= \varphi_w(h_t, x_t) \\ y_t &= f(h_t) \end{aligned}$$
System specified by parameters $w \in \mathbb{R}^d$.
Typical setting: $\varphi_w(h_t, x_t) = \tanh(A h_t + B x_t)$. Here, $w = (A, B)$.
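As a sketch, the recurrence above fits in a few lines of NumPy. The dimensions and the linear readout $f(h) = Ch$ are illustrative assumptions, not choices from the talk.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 4, 3                          # hidden and input dimensions (assumed)
A = 0.3 * rng.normal(size=(d, d))    # recurrent weights
B = rng.normal(size=(d, n))          # input weights
C = rng.normal(size=(1, d))          # readout f(h) = C h (assumed linear)

def run_rnn(xs, A, B, C):
    """Unroll h_{t+1} = tanh(A h_t + B x_t), y_t = f(h_t) over a sequence."""
    h = np.zeros(A.shape[0])
    ys = []
    for x in xs:
        h = np.tanh(A @ h + B @ x)   # state-transition map phi_w
        ys.append(C @ h)             # output y_t
    return np.array(ys)

xs = rng.normal(size=(10, n))        # a length-10 input sequence
ys = run_rnn(xs, A, B, C)
```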
Advantages of recurrent models

Expressive model class

Flexible formalism to cope with variable sequence length

Potential to capture long-term dependencies

Historical and cultural heritage (the watch your father gave you)
Training recurrent models
Update model parameters with gradient descent:
$$w \leftarrow w - \alpha \nabla_w \ell(w; (x, y))$$
Get derivatives using backpropagation
Need to store all $t$ hidden states for the backward pass!
Shortcomings of recurrent models

Finicky during training (vanishing/exploding gradients)

Extensive tuning necessary

Backprop prohibitive on long sequences

Truncation often necessary in practice
Empirical reality
Increasingly, feedforward models replace recurrent models in various
applications.
Language Modeling (GatedConv by Dauphin et al.),
Machine Translation (Transformer by Vaswani et al.),
Speech Synthesis (WaveNet by van den Oord et al.),
everything else (Bai et al.)
Question
Are all trainable recurrent models
inherently feedforward?
Sounds plausible to me, but it's hard
to characterize trainability
Main result
Stable recurrent models can be replaced by equivalent
feedforward models for both training and inference.
Put differently, any recurrent model is either
inherently feedforward or unstable.
Empirical work
Ways of making models (RNNs, LSTMs) stable, also during training
Stable recurrent models can achieve
competitive performance on various tasks.
Stability
State-transition map $\varphi$ is $\lambda$-contractive if
$$\Vert \varphi(h, x) - \varphi(h', x)\Vert \le \lambda \Vert h - h'\Vert$$
Stable: $\lambda < 1$
Stability
Linear dynamical system $\varphi(h, x) = Ah + Bx$ is stable if
$\Vert A \Vert < 1$.
Same for the recurrent neural network
$\varphi(h, x) = \tanh(Ah + Bx)$
(since $\tanh$ is $1$-Lipschitz.)
And, since you're asking,
the LSTM is stable if a corresponding norm condition
on its weight matrices holds.
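A minimal numerical sketch of the stability condition $\Vert A \Vert < 1$: check the spectral norm, and rescale when the condition fails. The rescaling projection is one simple way to enforce stability, assumed here for illustration.

```python
import numpy as np

def spectral_norm(A):
    """Largest singular value, i.e. the operator 2-norm ||A||."""
    return np.linalg.norm(A, 2)

def project_stable(A, margin=0.99):
    """Rescale A so that ||A|| <= margin < 1; stable A is left unchanged."""
    s = spectral_norm(A)
    return A if s <= margin else A * (margin / s)

rng = np.random.default_rng(1)
A = rng.normal(size=(8, 8))          # almost surely unstable at this scale
A_stable = project_stable(A)
```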
Truncated model
Model truncated at depth $k$:
$$\begin{aligned} h_{t+1}^k &= \varphi_w(h_t^k, x_t) \\ h_{t-k}^k &= 0 \\ y_t^k &= f(h_t^k) \end{aligned}$$
Autoregressive feedforward model of depth $k$ times the original depth
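Since the truncated model only ever looks at the last $k$ inputs, it can be sketched as a feedforward map over a sliding window. The tanh cell, readout, and dimensions are illustrative assumptions.

```python
import numpy as np

def truncated_output(xs, t, k, A, B, C):
    """Compute y_t^k: unroll phi_w over only the last k inputs, from h = 0."""
    h = np.zeros(A.shape[0])                 # h_{t-k}^k = 0
    for x in xs[max(0, t - k): t]:           # window of (at most) k inputs
        h = np.tanh(A @ h + B @ x)           # same state-transition map
    return C @ h                             # y_t^k = f(h_t^k)

rng = np.random.default_rng(2)
d, n = 4, 3
A = 0.5 * np.eye(d)                          # ||A|| = 0.5 < 1: stable
B = rng.normal(size=(d, n))
C = rng.normal(size=(1, d))
xs = rng.normal(size=(50, n))
y_k = truncated_output(xs, t=50, k=10, A=A, B=B, C=C)
```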
Inference
Theorem.
Assume $\varphi_w$ is $\lambda$-contractive for $\lambda < 1$ and Lipschitz.
Then $\Vert y_t - y_t^k \Vert \le \epsilon$
provided $k \gg \log\!\left(\frac{1}{(1-\lambda)\epsilon}\right)$.
Training with gradient descent
Let $w_N$ be the parameters of the recurrent model
after $N$ steps of gradient descent with a decaying $O(1/t)$ learning rate.
Let $w_N^k$ be the same for the truncated model.
Theorem.
Under stability, smoothness, and Lipschitz assumptions,
$\Vert w_N - w_N^k \Vert \le \epsilon$ provided $k \gg \log(N/\epsilon)$.
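A toy numerical sketch of this result: train a scalar stable recurrent model and its $k$-truncation with the same decaying step size, then compare the learned parameters. The model, loss, data, projection radius, and finite-difference gradients are all illustrative assumptions, not the paper's setup.

```python
import numpy as np

rng = np.random.default_rng(3)
xs = 0.5 * rng.normal(size=(100,))           # one toy input sequence
target = 0.3                                 # regression target for h_T

def final_state(w, xs, k=None):
    """h_{t+1} = w * tanh(h_t) + x_t; contractive whenever |w| < 1."""
    seq = xs if k is None else xs[-k:]       # truncation keeps the last k inputs
    h = 0.0
    for x in seq:
        h = w * np.tanh(h) + x
    return h

def grad(w, xs, k=None, eps=1e-6):
    """Finite-difference gradient of the squared loss (illustrative only)."""
    loss = lambda v: (final_state(v, xs, k) - target) ** 2
    return (loss(w + eps) - loss(w - eps)) / (2 * eps)

def train(k=None, steps=200, w0=0.1):
    w = w0
    for t in range(1, steps + 1):
        w = w - (0.1 / t) * grad(w, xs, k)   # decaying O(1/t) learning rate
        w = float(np.clip(w, -0.5, 0.5))     # project back onto a stable set
    return w

w_full, w_trunc = train(), train(k=20)       # full model vs. depth-20 truncation
```

With a contractive model, the two gradient trajectories differ by a quantity geometrically small in $k$, so the learned parameters end up very close.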
Training vs inference
Truncation affects neither training nor inference.
Image credit: Song Han
Contrast with weight pruning
Preserves accuracy when done after training
Can hurt performance when done before training
Proof idea
Inference:
Direct calculation using contractiveness assumption.
Training:
Uses stability properties of gradient descent.
Inspired by the "train
faster, generalize better" analysis.
[Hardt, Recht, Singer, 2015]
Proof of inference result
If $\varphi_w$ is $\lambda$-contractive, then:
$$\Vert h_t - h_t^k \Vert \le \lambda^k \Vert h_{t-k} - h_{t-k}^k \Vert = \lambda^k \Vert h_{t-k} \Vert$$
Now assume $\varphi_w$ is $L$-Lipschitz in $x_t$, and $\Vert x_t \Vert \le B$.
$$\Vert h_t \Vert \le \lambda \Vert h_{t-1} \Vert + LB \le \sum_{i=0}^{t} \lambda^i LB \le \frac{LB}{1-\lambda}\,.$$
Hence,
$$\Vert h_t - h_t^k \Vert \le \frac{\lambda^k LB}{1-\lambda}\,.$$
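The final bound can be checked numerically on a small tanh-RNN. Here $\lambda = \Vert A \Vert$, $L = \Vert B \Vert$, and the input bound are computed from randomly chosen, purely illustrative matrices and inputs.

```python
import numpy as np

rng = np.random.default_rng(4)
d, n = 6, 2
A = rng.normal(size=(d, d))
A *= 0.8 / np.linalg.norm(A, 2)              # enforce lambda = ||A|| = 0.8 < 1
B_mat = rng.normal(size=(d, n))
lam, L = 0.8, np.linalg.norm(B_mat, 2)       # L: Lipschitz constant in x

xs = rng.uniform(-1, 1, size=(60, n))
B_in = max(np.linalg.norm(x) for x in xs)    # input norm bound ||x_t|| <= B

def state(seq):
    """Unroll h <- tanh(A h + B x) from h = 0 over the given inputs."""
    h = np.zeros(d)
    for x in seq:
        h = np.tanh(A @ h + B_mat @ x)
    return h

t, k = 60, 15
gap = np.linalg.norm(state(xs[:t]) - state(xs[t - k:t]))  # ||h_t - h_t^k||
bound = lam ** k * L * B_in / (1 - lam)                   # lambda^k L B / (1 - lambda)
```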
Proof of training result
Step 1:
Show that stable models have vanishing gradients (over time).
I.e., argue that $\nabla_w \ell \approx \nabla_w \ell^k \pm \epsilon$
A (not so obvious) extension of the inference result
Step 2:
Show that gradient descent is insensitive to gradient differences.
Connection to algorithmic stability of gradient descent
Connection to algorithmic stability
Gradient descent is insensitive to perturbations in the
sample (algorithmic stability) [Hardt, Recht, Singer, 2015]
Similar analysis applies here
A decaying learning rate is necessary for the analysis in both cases.
Here, it is even necessary in experiments.
Word-level language modeling

Language modeling on Wikitext2. Similar to Penn Treebank data, but bigger.

One giant sequence.

Task: Predict next word.
Stable RNN slightly better than unstable RNN. Opposite for LSTM.
Polyphonic music modeling

Corpus of Bach chorales.

Each sequence encodes which piano keys were pressed at a given time.

Task: Predict the next chord in the sequence.
Stable/unstable LSTMs and RNNs have very similar performance.
Stable slightly better.
Character-level language modeling

Language modeling on Tolstoy's War and Peace.

One giant sequence.

Task: Predict next character.
Stable worse than unstable. No difference between LSTM and RNN.
Summary
Stable recurrent models do not provide long-term memory and can be replaced by
feedforward models without loss.
Empirical evidence suggests that stable models can achieve good performance.
Some speculation
Are all efficiently trainable recurrent models
inherently subsumed by
feedforward models?
In this work: Stability as a proxy for trainability.
Future work
When are recurrent models really needed (if they ever are)?
What is the price of stability (if there is any)?
What memory mechanisms are efficiently learnable?
What kind of memory do natural problems require?
What classes of dynamical systems are learnable?
(Tricky even for linear dynamical systems)