debug:pytorch_梯度出现NaN

文章目录
  1. 1. 梯度出现异常值:NaN
    1. 1.1. 定位方法:
    2. 1.2. 原因:
    3. 1.3. 官方文档参考

计算

梯度出现NaN

梯度出现异常值:NaN

定位方法:

使用如下代码设置,在出现NaN异常时程序会报错,便于定位错误代码

1
2
3
4
5
6
7
import torch
# 正向传播时:开启自动求导的异常侦测
torch.autograd.set_detect_anomaly(True)

# 反向传播时:在求导时开启侦测
with torch.autograd.detect_anomaly():
loss.backward()

原因:

很多网友提到,pytorch的求标准差的函数STD可能有问题。如果使用了类似会调用STD函数的各种Norm层就可能导致NAN问题。

官方文档参考