recompute

paddle.distributed.fleet.utils. recompute ( function, *args, **kwargs ) [源代码]

重新计算中间激活函数值来节省显存。

参数

  • function (paddle.nn.Sequential) - 模型前向传播的部分连续的层函数组成的序列, 它们的中间激活函数值将在前向传播过程中被释放掉来节省显存,并且在反向梯度计算的时候会重新被计算。

  • args (Tensor) - function的输入。

  • kwargs (Dict) - kwargs只应该包含preserve_rng_state的键值对,用来表示是否保存前向的rng,如果为True,那么在反向传播的重计算前向时会还原上次前向的rng值。默认preserve_rng_state为True。

返回

function作用在输入的输出

代码示例

COPY-FROM: paddle.distributed.fleet.utils.recompute