好玩的人工智能
快乐的深度学习

tf.nn.rnn_cell.LSTMCell中num_units参数解释

本文只是介绍tensorflow中的tf.nn.rnn_cell.LSTMCell中num_units,关于LSTM和如何使用请看前言的教程。
在使用Tensorflow跑LSTM的试验中, 有个num_units的参数,这个参数是什么意思呢?

先总结一下,num_units这个参数的大小就是LSTM输出结果的维度。例如num_units=128, 那么LSTM网络最后输出就是一个128维的向量。

我们先换个角度举个例子,最后再用公式来说明。

假设在我们的训练数据中,每一个样本 x 是 28*28 维的一个矩阵,那么将这个样本的每一行当成一个输入,通过28个时间步骤展开LSTM,在每一个LSTM单元,我们输入一行维度为28的向量,如下图所示。

那么,对每一个LSTM单元,参数 num_units=128 的话,就是每一个单元的输出为 128*1 的向量,在展开的网络维度来看,如下图所示,对于每一个输入28维的向量,LSTM单元都把它映射到128维的维度, 在下一个LSTM单元时,LSTM会接收上一个128维的输出,和新的28维的输入,处理之后再映射成一个新的128维的向量输出,就这么一直处理下去,知道网络中最后一个LSTM单元,输出一个128维的向量。

从LSTM的公式的角度看是什么原理呢?我们先看一下LSTM的结构和公式:

参数 num_units=128 的话,

  1. 对于公式 (1) ,h=128*1 维, x=28*1 维,[h,x]便等于156*1 维,W=128*156 维,所以 W*[h,x]=128*156 * 156*1=128*1, b=128*1 维, 所以 f=128*1+128*1=128*1 维;
  2. 对于公式 (2) 和 (3),同上可分析得 i=128*1 维,C(~)=128*1 维;
  3. 对于公式 (4) ,f(t)=128*1, C(t-1)=128*1, f(t) .* C(t-1) = 128*1 .* 128*1 = 128*1 , 同理可得 C(t)=128*1 维;
  4. 对于公式 (5) 和 (6) , 同理可得 O=128*1 维, h=O.*tanh(C)=128*1 维。

所以最后LSTM单元输出的h就是 128*1 的向量。

另外几个需要注意的地方:

1、 cell 的状态是一个向量,是有多个值的。。。一开始没有理解这点的时候怎么都想不明白

2、 上一次的状态 h(t-1)是怎么和下一次的输入 x(t) 结合(concat)起来的,这也是很多资料没有明白讲的地方,也很简单,concat, 直白的说就是把二者直接拼起来,比如 x是28位的向量,h(t-1)是128位的,那么拼起来就是156位的向量,就是这么简单。。

3、 cell 的权重是共享的,这是什么意思呢?这是指这张图片上有三个绿色的大框,代表三个 cell 对吧,但是实际上,它只是代表了一个 cell 在不同时序时候的状态,所有的数据只会通过一个 cell,然后不断更新它的权重。

4、那么一层的 LSTM 的参数有多少个?根据第 3 点的说明,我们知道参数的数量是由 cell 的数量决定的,这里只有一个 cell,所以参数的数量就是这个 cell 里面用到的参数个数。假设 num_units 是128,输入是28位的,那么根据上面的第 2 点,可以得到,四个小黄框的参数一共有 (128+28)*(128*4),也就是156 * 512,可以看看 TensorFlow 的最简单的 LSTM 的案例,中间层的参数就是这样,不过还要加上输出的时候的激活函数的参数,假设是10个类的话,就是128*10的 W 参数和10个bias 参数

5、cell 最上面的一条线的状态即 s(t) 代表了长时记忆,而下面的 h(t)则代表了工作记忆或短时记忆

未经允许不得转载:零点智能 » tf.nn.rnn_cell.LSTMCell中num_units参数解释
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址

零点智能 人工智能社区,加Q群:469331966

投稿&建议&加Q群