model.eval() 和 with torch.no_grad() 对比

文章目录
  1. 1. model.eval()
    1. 1.1. torch.no_grad()
  2. 2. 参考资料

两者目标不一样

model.eval()

eval()将通知所有层当前的使用模式(train/eval), 因为一些特殊的层会有两种工作方式.

  • batchnorm 在train时使用batch的数据归一化, 而在eval时使用统计数据进行归一化
  • dropout 在train时正常, 在eval时不会生效

torch.no_grad()

1
2
with torch.no_grad():
y = model(x)

no_grad()会关闭自动梯度引擎, 这会降低内存或显存的使用数量, 并且加快计算速度

参考资料