# Weight-initialization loop (fragment: `self` comes from the enclosing
# method, which is not part of this snippet -- presumably an nn.Module).
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        # Kaiming (He) normal init, scaled by fan-out for ReLU activations.
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        # Normalization layers start with scale 1 and shift 0.
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
torchvision.models.squeezenet 实现的权重初始化:
# Weight-initialization loop (fragment: `self` and `final_conv` come from
# the enclosing method, which is not part of this snippet).
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        if m is final_conv:
            # The layer bound to `final_conv` gets small Gaussian weights.
            init.normal_(m.weight, mean=0.0, std=0.01)
        else:
            # Every other convolution uses Kaiming (He) uniform init.
            init.kaiming_uniform_(m.weight)
        # Convolutions may be built with bias=False, hence the None check.
        if m.bias is not None:
            init.constant_(m.bias, 0)
torchvision.models.googlenet 实现的权重初始化:
def _initialize_weights(self):
    """Initialize the weights of every submodule in place.

    Conv2d and Linear weights are drawn from a normal distribution with
    standard deviation 0.01, truncated at two standard deviations (so every
    sample lies in [-0.02, 0.02]). Their biases are left untouched.
    BatchNorm2d layers get weight 1 and bias 0.
    """
    # Local import keeps scipy an optional dependency that is only needed
    # when this initializer runs; it is hoisted out of the loop so it is
    # not re-executed for every module.
    import scipy.stats as stats

    # truncnorm's bounds are expressed in units of `scale`, i.e. +/-2 sigma.
    truncated_normal = stats.truncnorm(-2, 2, scale=0.01)

    for m in self.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            values = torch.as_tensor(
                truncated_normal.rvs(m.weight.numel()),
                dtype=m.weight.dtype,
            )
            values = values.view(m.weight.size())
            # Copy without recording the write in the autograd graph.
            with torch.no_grad():
                m.weight.copy_(values)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
torchvision.models.mobilenet 实现的权重初始化:
# Weight initialization (fragment: `self` comes from the enclosing method,
# which is not part of this snippet -- presumably an nn.Module).
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        # Kaiming (He) normal init scaled by fan-out.
        nn.init.kaiming_normal_(m.weight, mode='fan_out')
        # Convolutions may be built with bias=False, hence the None check.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        # Normalization layers start with scale 1 and shift 0.
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        # Linear layers: small Gaussian weights (mean 0, std 0.01), zero bias.
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.zeros_(m.bias)
def _initialize_weights(self):
    """Initialize convolution and normalization layers in place.

    Conv2d weights use Kaiming (He) normal initialization with
    ``mode='fan_out'`` and ``nonlinearity='relu'``; convolution biases are
    left untouched. BatchNorm2d and GroupNorm layers get weight 1 and bias 0.
    """
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)