PyTorch 学习笔记

使用 PyTorch 的一些笔记,以防写完就忘,看完 API 又想起来,长此以往。

torch.nn

torch.nn.LSTM

LSTM 中的 hidden state 其实就是指每一个 LSTM cell 的输出,而 cell state 则是每次传递到下一层的「长时记忆」,我总觉得这个名字起的特别别扭,所以总不能很好的理解。下面这张图能更好的说明这些变量的意义。

LSTM

再来简单的回顾一下 LSTM 的几个公式

其中 $h_t$ 和 $c_t$ 就是所谓的 hidden statecell state 了。可以看到 LSTM 中所谓的 output gate,即 $o_t$ 其实是中间状态,它和 cell state 经过 $\tanh$ 相乘,得到了 hidden state,也就是输出值。

PyTorch 中 LSTM 的输出结果是一个二元组套二元组 (output, (h_n, c_n))。第一个 output 是每一个 timestamp 的输出,也就是每一个 cell 的 hidden state。第二个输出是一个二元组,分别表示最后一个 timestamp 的 hidden statecell state。因此,如果把 h_nc_n 记录下来,就可以保留整个 LSTM 的状态了。

PyTorch 中可以通过 bidirectional=True 来方便的将 LSTM 设置为双向,此时 output 会自动把每一个 timestamp 的正向和反向 LSTM 拼在一起。而 h_nc_n 的第一维长度会变为 2(单向是长度为 1)。而且此时有

即正向 output 的最后一个 timestemp(对应 LSTM 的最后一个 cell)的输出和正向的 hidden state 相同,反向 output 的最后一个 timestamp(对应 LSTM 的第一个 cell)的输出和反向的 hidden state 相同。

此外,在 PyTorch 中,LSTM 输出的形状和别的框架不太一样,它是序列长度优先的,(seq_len, batch_size, hz),如果觉得不习惯,可以通过 batch_first=True 来设定为 batch_size 优先。

评论