Weight Initialization - PyTorch Implementation

I previously learned a few simple initialization methods from cs231n (see the earlier post 权重初始化). Recently, while reading source code, I came across PyTorch's implementations of weight initialization.

Implementations

Weight initialization in torchvision.models.resnet

for m in self.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

Weight initialization in torchvision.models.squeezenet

for m in self.modules():
    if isinstance(m, nn.Conv2d):
        if m is final_conv:
            init.normal_(m.weight, mean=0.0, std=0.01)
        else:
            init.kaiming_uniform_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0)

Weight initialization in torchvision.models.googlenet

def _initialize_weights(self):
    for m in self.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            import scipy.stats as stats
            X = stats.truncnorm(-2, 2, scale=0.01)
            values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
            values = values.view(m.weight.size())
            with torch.no_grad():
                m.weight.copy_(values)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
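
Recent PyTorch versions provide torch.nn.init.trunc_normal_, which removes the need for the scipy detour above. A minimal sketch under that assumption (trunc_normal_ truncates at absolute bounds a and b, so two standard deviations of 0.01 correspond to ±0.02):

import torch
import torch.nn as nn

# a sketch of the same truncated-normal initialization without scipy;
# trunc_normal_ cuts at absolute values, hence a=-0.02, b=0.02 (i.e. +/- 2 std)
def _initialize_weights(self):
    for m in self.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-0.02, b=0.02)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)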

Weight initialization in torchvision.models.mobilenet

# weight initialization
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.zeros_(m.bias)

torch.nn.init

torch.nn.init implements the weight-initialization methods. The commonly used ones include:

  1. xavier_uniform_/xavier_normal_
  2. kaiming_normal_/kaiming_uniform_
  3. zeros_/ones_/constant_
  4. torch.nn.init.normal_

xavier_uniform_/xavier_normal_

# Xavier uniform distribution
torch.nn.init.xavier_uniform_(tensor, gain=1.0)
# Xavier normal distribution
torch.nn.init.xavier_normal_(tensor, gain=1.0)

These two methods come from the paper Understanding the difficulty of training deep feedforward neural networks.

Their goal is to keep the signal strength (measured by its variance) unchanged as it flows through the network.
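
Concretely, xavier_uniform_ samples from \(U(-a, a)\) with \(a = gain \times \sqrt{6/(fan\_in + fan\_out)}\), and xavier_normal_ samples from \(N(0, std^{2})\) with \(std = gain \times \sqrt{2/(fan\_in + fan\_out)}\). A minimal usage sketch (the tanh gain here is only an illustrative choice):

import torch
import torch.nn as nn

w = torch.empty(3, 5)
# match the gain to the activation that follows the layer
nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('tanh'))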

kaiming_normal_/kaiming_uniform_

# Kaiming normal distribution
torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
# Kaiming uniform distribution
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

These two methods come from the paper Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.

Xavier initialization assumes there are no activation functions in the network, but activation functions change the distribution of the data flowing through it; Kaiming initialization was proposed precisely to address this.
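
A minimal sketch of applying it to a convolution followed by ReLU (the layer shape is illustrative):

import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
# mode='fan_in' preserves the variance of activations in the forward pass;
# mode='fan_out' preserves the variance of gradients in the backward pass
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
if conv.bias is not None:
    nn.init.zeros_(conv.bias)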

zeros_/ones_/constant_

# fill with zeros
torch.nn.init.zeros_(tensor)
# fill with ones
torch.nn.init.ones_(tensor)
# fill with the constant value val
torch.nn.init.constant_(tensor, val)

All three fill the tensor with a constant value. For example:

>>> w = torch.empty(2, 3)
# fill with ones
>>> torch.nn.init.ones_(w)
tensor([[1., 1., 1.],
        [1., 1., 1.]])
# fill with zeros
>>> torch.nn.init.zeros_(w)
tensor([[0., 0., 0.],
        [0., 0., 0.]])
# fill with the constant 0.3
>>> torch.nn.init.constant_(w, 0.3)
tensor([[0.3000, 0.3000, 0.3000],
        [0.3000, 0.3000, 0.3000]])

torch.nn.init.normal_

torch.nn.init.normal_(tensor, mean=0.0, std=1.0)

Normal-distribution initialization: fills the input tensor with values drawn from \(N(mean, std^{2})\), a normal distribution with the given mean and a variance of std**2.

>>> w = torch.empty(2, 3)
>>> torch.nn.init.normal_(w)
tensor([[-0.2951,  0.2850,  0.7402],
        [-0.2483, -2.0617, -1.0293]])

Summary

To summarize how weight initialization is used in PyTorch, define an initialization function as follows:

def _initialize_weights(self):
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

Note: if the convolution is not followed by an activation function, xavier_normal_ can be used instead.
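
A minimal sketch of wiring such a function into a model (the layer configuration is illustrative):

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        # run once, after all submodules have been registered
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)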

Related Reading