LSTMCell

class paddle.fluid.layers. LSTMCell ( hidden_size, param_attr=None, bias_attr=None, gate_activation=None, activation=None, forget_bias=1.0, dtype='float32', name='LSTMCell' ) [源代码]

长短期记忆单元(Long-Short Term Memory)。通过对 fluid.contrib.layers.rnn_impl.BasicLSTMUnit 包装,来让它可以应用于RNNCell。

公式如下:

\[\begin{split}i_{t} &= act_g \left ( W_{x_{i}}x_{t}+W_{h_{i}}h_{t-1}+b_{i} \right ) \\ f_{t} &= act_g \left ( W_{x_{f}}x_{t}+W_{h_{f}}h_{t-1}+b_{f}+forget\_bias \right ) \\ c_{t} &= f_{t}c_{t-1}+i_{t}act_h\left ( W_{x_{c}}x_{t} +W_{h_{c}}h_{t-1}+b_{c}\right ) \\ o_{t} &= act_g\left ( W_{x_{o}}x_{t}+W_{h_{o}}h_{t-1}+b_{o} \right ) \\ h_{t} &= o_{t}act_h \left ( c_{t} \right )\end{split}\]

更多细节可以参考 RECURRENT NEURAL NETWORK REGULARIZATION

参数

  • hidden_size (int) - LSTMCell中的隐藏层大小。

  • param_attr (ParamAttr,可选) - 指定权重参数属性的对象。默认值为None,表示使用默认的权重参数属性。具体用法请参见 ParamAttr

  • bias_attr (ParamAttr,可选) - 指定偏置参数属性的对象。默认值为None,表示使用默认的偏置参数属性。具体用法请参见 ParamAttr

  • gate_activation (function,可选) - \(act_g\) 的激活函数。默认值为 fluid.layers.sigmoid

  • activation (function,可选) - \(act_c\) 的激活函数。默认值为 fluid.layers.tanh

  • forget_bias (float,可选) - 计算遗忘们时使用的遗忘偏置。默认值为 1.0。

  • dtype (string,可选) - 此Cell中使用的数据类型。默认值为 float32

  • name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。

返回

LSTMCell类的实例对象。

代码示例

import paddle.fluid.layers as layers
cell = layers.LSTMCell(hidden_size=256)

方法

call(inputs, states)

执行GRU的计算。

参数

  • input (Variable) - 输入,形状为 \([batch\_size,input\_size]\) 的tensor,对应于公式中的 \(x_t\)。数据类型应为float32。

  • states (Variable) - 状态,包含两个tensor的列表,每个tensor形状为 \([batch\_size,hidden\_size]\)。对应于公式中的 \(h_{t-1}, c_{t-1}\)。数据类型应为float32。

返回 一个元组 (outputs, new_states),其中 outputs 是形状为 \([batch\_size,hidden\_size]\) 的tensor,对应于公式中的 \(h_{t}\)new_states 是一个列表,包含形状为 \([batch_size,hidden_size]\) 的两个tensor变量,它们对应于公式中的 \(h_{t}, c_{t}\)。这些tensor的数据类型都与 state 的数据类型相同。

返回类型 tuple

state_shape()

LSTMCell的 state_shape 是一个具有两个形状的列表:\([[hidden\_size], [hidden\_size]]\) (batch大小为-1,自动插入到形状中)。这两个形状分别对应于公式中的 \(h_{t-1}\) and \(c_{t-1}\)

参数 无。

返回 LSTMCell的 state_shape

返回类型 list